From b9ed7b1650475420cc5bbec6c31476cc098790d5 Mon Sep 17 00:00:00 2001 From: Antoine Collas Date: Tue, 21 Mar 2023 15:18:09 +0100 Subject: [MRG] Make partial_wasserstein, partial_wasserstein2 and entropic_partial_wasserstein work with backend (#449) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add test of partial_wasserstein with torch tensors * WIP: differentiable ot.partial.partial_wasserstein * change test of torch partial * make partial_wasserstein2 work with torch * test backward through ot.partial.partial_wasserstein2 * add test of entropic_partial_wasserstein with torch tensors * make entropic_partial_wasserstein work with torch tensors * add test of backward through entropic_partial_wasserstein * rm unused import * test partial_wasserstein with all backends * tests of partial fcts: check if torch is available * partial: check if marginals are empty arrays * add tests when marginals are empty arrays and/or m=None * add PR to RELEASES.md --------- Co-authored-by: Antoine Collas <22830806+antoinecollas@users.noreply.github.com> Co-authored-by: RĂ©mi Flamary --- ot/partial.py | 95 ++++++++++++++++++++++++++++++++++------------------------- 1 file changed, 55 insertions(+), 40 deletions(-) (limited to 'ot/partial.py') diff --git a/ot/partial.py b/ot/partial.py index eae91c4..bf4119d 100755 --- a/ot/partial.py +++ b/ot/partial.py @@ -120,7 +120,7 @@ def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False, nx = get_backend(a, b, M) - if nx.sum(a) > 1 or nx.sum(b) > 1: + if nx.sum(a) > 1 + 1e-15 or nx.sum(b) > 1 + 1e-15: # 1e-15 for numerical errors raise ValueError("Problem infeasible. Check that a and b are in the " "simplex") @@ -270,36 +270,43 @@ def partial_wasserstein(a, b, M, m=None, nb_dummies=1, log=False, **kwargs): nx = get_backend(a, b, M) + dim_a, dim_b = M.shape + if len(a) == 0: + a = nx.ones(dim_a, type_as=a) / dim_a + if len(b) == 0: + b = nx.ones(dim_b, type_as=b) / dim_b + if m is None: return partial_wasserstein_lagrange(a, b, M, log=log, **kwargs) elif m < 0: raise ValueError("Problem infeasible. Parameter m should be greater" " than 0.") - elif m > nx.min((nx.sum(a), nx.sum(b))): + elif m > nx.min(nx.stack((nx.sum(a), nx.sum(b)))): raise ValueError("Problem infeasible. Parameter m should lower or" " equal than min(|a|_1, |b|_1).") - a0, b0, M0 = a, b, M - # convert to humpy - a, b, M = nx.to_numpy(a, b, M) - - b_extended = np.append(b, [(np.sum(a) - m) / nb_dummies] * nb_dummies) - a_extended = np.append(a, [(np.sum(b) - m) / nb_dummies] * nb_dummies) - M_extended = np.zeros((len(a_extended), len(b_extended))) - M_extended[-nb_dummies:, -nb_dummies:] = np.max(M) * 2 - M_extended[:len(a), :len(b)] = M + b_extension = nx.ones(nb_dummies, type_as=b) * (nx.sum(a) - m) / nb_dummies + b_extended = nx.concatenate((b, b_extension)) + a_extension = nx.ones(nb_dummies, type_as=a) * (nx.sum(b) - m) / nb_dummies + a_extended = nx.concatenate((a, a_extension)) + M_extension = nx.ones((nb_dummies, nb_dummies), type_as=M) * nx.max(M) * 2 + M_extended = nx.concatenate( + (nx.concatenate((M, nx.zeros((M.shape[0], M_extension.shape[1]))), axis=1), + nx.concatenate((nx.zeros((M_extension.shape[0], M.shape[1])), M_extension), axis=1)), + axis=0 + ) gamma, log_emd = emd(a_extended, b_extended, M_extended, log=True, **kwargs) - gamma = nx.from_numpy(gamma[:len(a), :len(b)], type_as=M) + gamma = gamma[:len(a), :len(b)] if log_emd['warning'] is not None: raise ValueError("Error in the EMD resolution: try to increase the" " number of dummy points") - log_emd['partial_w_dist'] = nx.sum(M0 * gamma) - log_emd['u'] = nx.from_numpy(log_emd['u'][:len(a)], type_as=a0) - log_emd['v'] = nx.from_numpy(log_emd['v'][:len(b)], type_as=b0) + log_emd['partial_w_dist'] = nx.sum(M * gamma) + log_emd['u'] = log_emd['u'][:len(a)] + log_emd['v'] = log_emd['v'][:len(b)] if log: return gamma, log_emd @@ -389,14 +396,18 @@ def partial_wasserstein2(a, b, M, m=None, nb_dummies=1, log=False, **kwargs): NeurIPS. """ + a, b, M = list_to_array(a, b, M) + + nx = get_backend(a, b, M) + partial_gw, log_w = partial_wasserstein(a, b, M, m, nb_dummies, log=True, **kwargs) log_w['T'] = partial_gw if log: - return np.sum(partial_gw * M), log_w + return nx.sum(partial_gw * M), log_w else: - return np.sum(partial_gw * M) + return nx.sum(partial_gw * M) def gwgrad_partial(C1, C2, T): @@ -838,60 +849,64 @@ def entropic_partial_wasserstein(a, b, M, reg, m=None, numItermax=1000, ot.partial.partial_wasserstein: exact Partial Wasserstein """ - a = np.asarray(a, dtype=np.float64) - b = np.asarray(b, dtype=np.float64) - M = np.asarray(M, dtype=np.float64) + a, b, M = list_to_array(a, b, M) + + nx = get_backend(a, b, M) dim_a, dim_b = M.shape - dx = np.ones(dim_a, dtype=np.float64) - dy = np.ones(dim_b, dtype=np.float64) + dx = nx.ones(dim_a, type_as=a) + dy = nx.ones(dim_b, type_as=b) if len(a) == 0: - a = np.ones(dim_a, dtype=np.float64) / dim_a + a = nx.ones(dim_a, type_as=a) / dim_a if len(b) == 0: - b = np.ones(dim_b, dtype=np.float64) / dim_b + b = nx.ones(dim_b, type_as=b) / dim_b if m is None: - m = np.min((np.sum(a), np.sum(b))) * 1.0 + m = nx.min(nx.stack((nx.sum(a), nx.sum(b)))) * 1.0 if m < 0: raise ValueError("Problem infeasible. Parameter m should be greater" " than 0.") - if m > np.min((np.sum(a), np.sum(b))): + if m > nx.min(nx.stack((nx.sum(a), nx.sum(b)))): raise ValueError("Problem infeasible. Parameter m should lower or" " equal than min(|a|_1, |b|_1).") log_e = {'err': []} - # Next 3 lines equivalent to K=np.exp(-M/reg), but faster to compute - K = np.empty(M.shape, dtype=M.dtype) - np.divide(M, -reg, out=K) - np.exp(K, out=K) - np.multiply(K, m / np.sum(K), out=K) + if type(a) == type(b) == type(M) == np.ndarray: + # Next 3 lines equivalent to K=nx.exp(-M/reg), but faster to compute + K = np.empty(M.shape, dtype=M.dtype) + np.divide(M, -reg, out=K) + np.exp(K, out=K) + np.multiply(K, m / np.sum(K), out=K) + else: + K = nx.exp(-M / reg) + K = K * m / nx.sum(K) err, cpt = 1, 0 - q1 = np.ones(K.shape) - q2 = np.ones(K.shape) - q3 = np.ones(K.shape) + q1 = nx.ones(K.shape, type_as=K) + q2 = nx.ones(K.shape, type_as=K) + q3 = nx.ones(K.shape, type_as=K) while (err > stopThr and cpt < numItermax): Kprev = K K = K * q1 - K1 = np.dot(np.diag(np.minimum(a / np.sum(K, axis=1), dx)), K) + K1 = nx.dot(nx.diag(nx.minimum(a / nx.sum(K, axis=1), dx)), K) q1 = q1 * Kprev / K1 K1prev = K1 K1 = K1 * q2 - K2 = np.dot(K1, np.diag(np.minimum(b / np.sum(K1, axis=0), dy))) + K2 = nx.dot(K1, nx.diag(nx.minimum(b / nx.sum(K1, axis=0), dy))) q2 = q2 * K1prev / K2 K2prev = K2 K2 = K2 * q3 - K = K2 * (m / np.sum(K2)) + K = K2 * (m / nx.sum(K2)) q3 = q3 * K2prev / K - if np.any(np.isnan(K)) or np.any(np.isinf(K)): + if nx.any(nx.isnan(K)) or nx.any(nx.isinf(K)): print('Warning: numerical errors at iteration', cpt) break if cpt % 10 == 0: - err = np.linalg.norm(Kprev - K) + err = nx.norm(Kprev - K) if log: log_e['err'].append(err) if verbose: @@ -901,7 +916,7 @@ def entropic_partial_wasserstein(a, b, M, reg, m=None, numItermax=1000, print('{:5d}|{:8e}|'.format(cpt, err)) cpt = cpt + 1 - log_e['partial_w_dist'] = np.sum(M * K) + log_e['partial_w_dist'] = nx.sum(M * K) if log: return K, log_e else: -- cgit v1.2.3