From 0411ea22a96f9c22af30156b45c16ef39ffb520d Mon Sep 17 00:00:00 2001 From: RĂ©mi Flamary Date: Thu, 15 Dec 2022 09:28:01 +0100 Subject: [MRG] New API for OT solver (with pre-computed ground cost matrix) (#388) * new API for OT solver * use itertools for product of parameters * add tests for result class * add tests for result class * add tests for result class last time? * add sinkhorn * make partial OT bckend compatible * add TV as unbalanced flavor * better tests * make smoth backend compatible and add l2 tregularizatio to solve * add reularizedd unbalanced * add test for more complex attibutes * add test for more complex attibutes * add generic unbalaned solver and implement it for ot.solve * add entropy to possible regularization * star of documentation for ot.solv * weird new pep8 * documenttaion for function ot.solve done * pep8 * Update ot/solvers.py Co-authored-by: Alexandre Gramfort * update release file * Apply suggestions from code review Co-authored-by: Alexandre Gramfort * add test NotImplemented * pep8 * pep8gcmp pep8! * compute kl in backend * debug tensorflow kl backend Co-authored-by: Alexandre Gramfort --- ot/partial.py | 47 ++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 38 insertions(+), 9 deletions(-) (limited to 'ot/partial.py') diff --git a/ot/partial.py b/ot/partial.py index 0a9e450..eae91c4 100755 --- a/ot/partial.py +++ b/ot/partial.py @@ -8,6 +8,8 @@ Partial OT solvers import numpy as np from .lp import emd +from .backend import get_backend +from .utils import list_to_array def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False, @@ -114,14 +116,22 @@ def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False, ot.partial.partial_wasserstein : Partial Wasserstein with fixed mass """ - if np.sum(a) > 1 or np.sum(b) > 1: + a, b, M = list_to_array(a, b, M) + + nx = get_backend(a, b, M) + + if nx.sum(a) > 1 or nx.sum(b) > 1: raise ValueError("Problem infeasible. Check that a and b are in the " "simplex") if reg_m is None: - reg_m = np.max(M) + 1 - if reg_m < -np.max(M): - return np.zeros((len(a), len(b))) + reg_m = float(nx.max(M)) + 1 + if reg_m < -nx.max(M): + return nx.zeros((len(a), len(b)), type_as=M) + + a0, b0, M0 = a, b, M + # convert to humpy + a, b, M = nx.to_numpy(a, b, M) eps = 1e-20 M = np.asarray(M, dtype=np.float64) @@ -149,10 +159,16 @@ def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False, gamma = np.zeros((len(a), len(b))) gamma[np.ix_(idx_x, idx_y)] = gamma_extended[:-nb_dummies, :-nb_dummies] + # convert back to backend + gamma = nx.from_numpy(gamma, type_as=M0) + if log_emd['warning'] is not None: raise ValueError("Error in the EMD resolution: try to increase the" " number of dummy points") - log_emd['cost'] = np.sum(gamma * M) + log_emd['cost'] = nx.sum(gamma * M0) + log_emd['u'] = nx.from_numpy(log_emd['u'], type_as=a0) + log_emd['v'] = nx.from_numpy(log_emd['v'], type_as=b0) + if log: return gamma, log_emd else: @@ -250,15 +266,23 @@ def partial_wasserstein(a, b, M, m=None, nb_dummies=1, log=False, **kwargs): entropic regularization parameter """ + a, b, M = list_to_array(a, b, M) + + nx = get_backend(a, b, M) + if m is None: return partial_wasserstein_lagrange(a, b, M, log=log, **kwargs) elif m < 0: raise ValueError("Problem infeasible. Parameter m should be greater" " than 0.") - elif m > np.min((np.sum(a), np.sum(b))): + elif m > nx.min((nx.sum(a), nx.sum(b))): raise ValueError("Problem infeasible. Parameter m should lower or" " equal than min(|a|_1, |b|_1).") + a0, b0, M0 = a, b, M + # convert to humpy + a, b, M = nx.to_numpy(a, b, M) + b_extended = np.append(b, [(np.sum(a) - m) / nb_dummies] * nb_dummies) a_extended = np.append(a, [(np.sum(b) - m) / nb_dummies] * nb_dummies) M_extended = np.zeros((len(a_extended), len(b_extended))) @@ -267,15 +291,20 @@ def partial_wasserstein(a, b, M, m=None, nb_dummies=1, log=False, **kwargs): gamma, log_emd = emd(a_extended, b_extended, M_extended, log=True, **kwargs) + + gamma = nx.from_numpy(gamma[:len(a), :len(b)], type_as=M) + if log_emd['warning'] is not None: raise ValueError("Error in the EMD resolution: try to increase the" " number of dummy points") - log_emd['partial_w_dist'] = np.sum(M * gamma[:len(a), :len(b)]) + log_emd['partial_w_dist'] = nx.sum(M0 * gamma) + log_emd['u'] = nx.from_numpy(log_emd['u'][:len(a)], type_as=a0) + log_emd['v'] = nx.from_numpy(log_emd['v'][:len(b)], type_as=b0) if log: - return gamma[:len(a), :len(b)], log_emd + return gamma, log_emd else: - return gamma[:len(a), :len(b)] + return gamma def partial_wasserstein2(a, b, M, m=None, nb_dummies=1, log=False, **kwargs): -- cgit v1.2.3