summaryrefslogtreecommitdiff
path: root/ot/partial.py
diff options
context:
space:
mode:
authorRĂ©mi Flamary <remi.flamary@gmail.com>2022-12-15 09:28:01 +0100
committerGitHub <noreply@github.com>2022-12-15 09:28:01 +0100
commit0411ea22a96f9c22af30156b45c16ef39ffb520d (patch)
tree7c131ad804d5b16a8c362c2fe296350a770400df /ot/partial.py
parent8490196dcc982c492b7565e1ec4de5f75f006acf (diff)
[MRG] New API for OT solver (with pre-computed ground cost matrix) (#388)
* new API for OT solver * use itertools for product of parameters * add tests for result class * add tests for result class * add tests for result class last time? * add sinkhorn * make partial OT bckend compatible * add TV as unbalanced flavor * better tests * make smoth backend compatible and add l2 tregularizatio to solve * add reularizedd unbalanced * add test for more complex attibutes * add test for more complex attibutes * add generic unbalaned solver and implement it for ot.solve * add entropy to possible regularization * star of documentation for ot.solv * weird new pep8 * documenttaion for function ot.solve done * pep8 * Update ot/solvers.py Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * update release file * Apply suggestions from code review Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * add test NotImplemented * pep8 * pep8gcmp pep8! * compute kl in backend * debug tensorflow kl backend Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
Diffstat (limited to 'ot/partial.py')
-rwxr-xr-xot/partial.py47
1 files changed, 38 insertions, 9 deletions
diff --git a/ot/partial.py b/ot/partial.py
index 0a9e450..eae91c4 100755
--- a/ot/partial.py
+++ b/ot/partial.py
@@ -8,6 +8,8 @@ Partial OT solvers
import numpy as np
from .lp import emd
+from .backend import get_backend
+from .utils import list_to_array
def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False,
@@ -114,14 +116,22 @@ def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False,
ot.partial.partial_wasserstein : Partial Wasserstein with fixed mass
"""
- if np.sum(a) > 1 or np.sum(b) > 1:
+ a, b, M = list_to_array(a, b, M)
+
+ nx = get_backend(a, b, M)
+
+ if nx.sum(a) > 1 or nx.sum(b) > 1:
raise ValueError("Problem infeasible. Check that a and b are in the "
"simplex")
if reg_m is None:
- reg_m = np.max(M) + 1
- if reg_m < -np.max(M):
- return np.zeros((len(a), len(b)))
+ reg_m = float(nx.max(M)) + 1
+ if reg_m < -nx.max(M):
+ return nx.zeros((len(a), len(b)), type_as=M)
+
+ a0, b0, M0 = a, b, M
+ # convert to humpy
+ a, b, M = nx.to_numpy(a, b, M)
eps = 1e-20
M = np.asarray(M, dtype=np.float64)
@@ -149,10 +159,16 @@ def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False,
gamma = np.zeros((len(a), len(b)))
gamma[np.ix_(idx_x, idx_y)] = gamma_extended[:-nb_dummies, :-nb_dummies]
+ # convert back to backend
+ gamma = nx.from_numpy(gamma, type_as=M0)
+
if log_emd['warning'] is not None:
raise ValueError("Error in the EMD resolution: try to increase the"
" number of dummy points")
- log_emd['cost'] = np.sum(gamma * M)
+ log_emd['cost'] = nx.sum(gamma * M0)
+ log_emd['u'] = nx.from_numpy(log_emd['u'], type_as=a0)
+ log_emd['v'] = nx.from_numpy(log_emd['v'], type_as=b0)
+
if log:
return gamma, log_emd
else:
@@ -250,15 +266,23 @@ def partial_wasserstein(a, b, M, m=None, nb_dummies=1, log=False, **kwargs):
entropic regularization parameter
"""
+ a, b, M = list_to_array(a, b, M)
+
+ nx = get_backend(a, b, M)
+
if m is None:
return partial_wasserstein_lagrange(a, b, M, log=log, **kwargs)
elif m < 0:
raise ValueError("Problem infeasible. Parameter m should be greater"
" than 0.")
- elif m > np.min((np.sum(a), np.sum(b))):
+ elif m > nx.min((nx.sum(a), nx.sum(b))):
raise ValueError("Problem infeasible. Parameter m should lower or"
" equal than min(|a|_1, |b|_1).")
+ a0, b0, M0 = a, b, M
+ # convert to humpy
+ a, b, M = nx.to_numpy(a, b, M)
+
b_extended = np.append(b, [(np.sum(a) - m) / nb_dummies] * nb_dummies)
a_extended = np.append(a, [(np.sum(b) - m) / nb_dummies] * nb_dummies)
M_extended = np.zeros((len(a_extended), len(b_extended)))
@@ -267,15 +291,20 @@ def partial_wasserstein(a, b, M, m=None, nb_dummies=1, log=False, **kwargs):
gamma, log_emd = emd(a_extended, b_extended, M_extended, log=True,
**kwargs)
+
+ gamma = nx.from_numpy(gamma[:len(a), :len(b)], type_as=M)
+
if log_emd['warning'] is not None:
raise ValueError("Error in the EMD resolution: try to increase the"
" number of dummy points")
- log_emd['partial_w_dist'] = np.sum(M * gamma[:len(a), :len(b)])
+ log_emd['partial_w_dist'] = nx.sum(M0 * gamma)
+ log_emd['u'] = nx.from_numpy(log_emd['u'][:len(a)], type_as=a0)
+ log_emd['v'] = nx.from_numpy(log_emd['v'][:len(b)], type_as=b0)
if log:
- return gamma[:len(a), :len(b)], log_emd
+ return gamma, log_emd
else:
- return gamma[:len(a), :len(b)]
+ return gamma
def partial_wasserstein2(a, b, M, m=None, nb_dummies=1, log=False, **kwargs):