diff options
author | Laetitia Chapel <laetitia.chapel@univ-ubs.fr> | 2020-04-16 15:52:00 +0200 |
---|---|---|
committer | Laetitia Chapel <laetitia.chapel@univ-ubs.fr> | 2020-04-16 15:52:00 +0200 |
commit | ef7c11a5df3cf6c82864472f0cfa65d6b2036f2f (patch) | |
tree | 1b9cc1d6c281a1eca885212aa3857d48ded4d695 /ot | |
parent | 18b64556aaa477b5499dc05110c96d32b04147ff (diff) |
partial with python 3.8
Diffstat (limited to 'ot')
-rwxr-xr-x | ot/partial.py | 12 |
1 files changed, 6 insertions, 6 deletions
diff --git a/ot/partial.py b/ot/partial.py index 8698d9d..726a590 100755 --- a/ot/partial.py +++ b/ot/partial.py @@ -232,7 +232,7 @@ def partial_wasserstein(a, b, M, m=None, nb_dummies=1, log=False, **kwargs): b_extended = np.append(b, [(np.sum(a) - m) / nb_dummies] * nb_dummies) a_extended = np.append(a, [(np.sum(b) - m) / nb_dummies] * nb_dummies) - M_extended = np.ones((len(a_extended), len(b_extended))) * 0 + M_extended = np.zeros((len(a_extended), len(b_extended))) M_extended[-1, -1] = np.max(M) * 1e5 M_extended[:len(a), :len(b)] = M @@ -510,9 +510,9 @@ def partial_gromov_wasserstein(C1, C2, p, q, m=None, nb_dummies=1, G0=None, Gprev = G0 M = gwgrad_partial(C1, C2, G0) - M[M < eps] = np.quantile(M[M > eps], thres) + M[M < eps] = np.quantile(M, thres) - M_emd = np.ones(dim_G_extended) * np.max(M) * 1e2 + M_emd = np.zeros(dim_G_extended) M_emd[:len(p), :len(q)] = M M_emd[-nb_dummies:, -nb_dummies:] = np.max(M) * 1e5 M_emd = np.asarray(M_emd, dtype=np.float64) @@ -729,8 +729,8 @@ def entropic_partial_wasserstein(a, b, M, reg, m=None, numItermax=1000, M = np.asarray(M, dtype=np.float64) dim_a, dim_b = M.shape - dx = np.ones(dim_a) - dy = np.ones(dim_b) + dx = np.ones(dim_a, dtype=np.float64) + dy = np.ones(dim_b, dtype=np.float64) if len(a) == 0: a = np.ones(dim_a, dtype=np.float64) / dim_a @@ -738,7 +738,7 @@ def entropic_partial_wasserstein(a, b, M, reg, m=None, numItermax=1000, b = np.ones(dim_b, dtype=np.float64) / dim_b if m is None: - m = np.min((np.sum(a), np.sum(b))) + m = np.min((np.sum(a), np.sum(b))) * 1.0 if m < 0: raise ValueError("Problem infeasible. Parameter m should be greater" " than 0.") |