From 0411ea22a96f9c22af30156b45c16ef39ffb520d Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Thu, 15 Dec 2022 09:28:01 +0100 Subject: [MRG] New API for OT solver (with pre-computed ground cost matrix) (#388) * new API for OT solver * use itertools for product of parameters * add tests for result class * add tests for result class * add tests for result class last time? * add sinkhorn * make partial OT backend compatible * add TV as unbalanced flavor * better tests * make smooth backend compatible and add l2 regularization to solve * add regularized unbalanced * add test for more complex attributes * add test for more complex attributes * add generic unbalanced solver and implement it for ot.solve * add entropy to possible regularization * start of documentation for ot.solve * weird new pep8 * documentation for function ot.solve done * pep8 * Update ot/solvers.py Co-authored-by: Alexandre Gramfort * update release file * Apply suggestions from code review Co-authored-by: Alexandre Gramfort * add test NotImplemented * pep8 * pep8gcmp pep8! * compute kl in backend * debug tensorflow kl backend Co-authored-by: Alexandre Gramfort --- ot/partial.py | 47 ++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 38 insertions(+), 9 deletions(-) (limited to 'ot/partial.py') diff --git a/ot/partial.py b/ot/partial.py index 0a9e450..eae91c4 100755 --- a/ot/partial.py +++ b/ot/partial.py @@ -8,6 +8,8 @@ Partial OT solvers import numpy as np from .lp import emd +from .backend import get_backend +from .utils import list_to_array def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False, @@ -114,14 +116,22 @@ def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False, ot.partial.partial_wasserstein : Partial Wasserstein with fixed mass """ - if np.sum(a) > 1 or np.sum(b) > 1: + a, b, M = list_to_array(a, b, M) + + nx = get_backend(a, b, M) + + if nx.sum(a) > 1 or nx.sum(b) > 1: raise ValueError("Problem infeasible.
Check that a and b are in the " "simplex") if reg_m is None: - reg_m = np.max(M) + 1 - if reg_m < -np.max(M): - return np.zeros((len(a), len(b))) + reg_m = float(nx.max(M)) + 1 + if reg_m < -nx.max(M): + return nx.zeros((len(a), len(b)), type_as=M) + + a0, b0, M0 = a, b, M + # convert to humpy + a, b, M = nx.to_numpy(a, b, M) eps = 1e-20 M = np.asarray(M, dtype=np.float64) @@ -149,10 +159,16 @@ def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False, gamma = np.zeros((len(a), len(b))) gamma[np.ix_(idx_x, idx_y)] = gamma_extended[:-nb_dummies, :-nb_dummies] + # convert back to backend + gamma = nx.from_numpy(gamma, type_as=M0) + if log_emd['warning'] is not None: raise ValueError("Error in the EMD resolution: try to increase the" " number of dummy points") - log_emd['cost'] = np.sum(gamma * M) + log_emd['cost'] = nx.sum(gamma * M0) + log_emd['u'] = nx.from_numpy(log_emd['u'], type_as=a0) + log_emd['v'] = nx.from_numpy(log_emd['v'], type_as=b0) + if log: return gamma, log_emd else: @@ -250,15 +266,23 @@ def partial_wasserstein(a, b, M, m=None, nb_dummies=1, log=False, **kwargs): entropic regularization parameter """ + a, b, M = list_to_array(a, b, M) + + nx = get_backend(a, b, M) + if m is None: return partial_wasserstein_lagrange(a, b, M, log=log, **kwargs) elif m < 0: raise ValueError("Problem infeasible. Parameter m should be greater" " than 0.") - elif m > np.min((np.sum(a), np.sum(b))): + elif m > nx.min((nx.sum(a), nx.sum(b))): raise ValueError("Problem infeasible. 
Parameter m should lower or" " equal than min(|a|_1, |b|_1).") + a0, b0, M0 = a, b, M + # convert to humpy + a, b, M = nx.to_numpy(a, b, M) + b_extended = np.append(b, [(np.sum(a) - m) / nb_dummies] * nb_dummies) a_extended = np.append(a, [(np.sum(b) - m) / nb_dummies] * nb_dummies) M_extended = np.zeros((len(a_extended), len(b_extended))) @@ -267,15 +291,20 @@ def partial_wasserstein(a, b, M, m=None, nb_dummies=1, log=False, **kwargs): gamma, log_emd = emd(a_extended, b_extended, M_extended, log=True, **kwargs) + + gamma = nx.from_numpy(gamma[:len(a), :len(b)], type_as=M) + if log_emd['warning'] is not None: raise ValueError("Error in the EMD resolution: try to increase the" " number of dummy points") - log_emd['partial_w_dist'] = np.sum(M * gamma[:len(a), :len(b)]) + log_emd['partial_w_dist'] = nx.sum(M0 * gamma) + log_emd['u'] = nx.from_numpy(log_emd['u'][:len(a)], type_as=a0) + log_emd['v'] = nx.from_numpy(log_emd['v'][:len(b)], type_as=b0) if log: - return gamma[:len(a), :len(b)], log_emd + return gamma, log_emd else: - return gamma[:len(a), :len(b)] + return gamma def partial_wasserstein2(a, b, M, m=None, nb_dummies=1, log=False, **kwargs): -- cgit v1.2.3 From b9ed7b1650475420cc5bbec6c31476cc098790d5 Mon Sep 17 00:00:00 2001 From: Antoine Collas Date: Tue, 21 Mar 2023 15:18:09 +0100 Subject: [MRG] Make partial_wasserstein, partial_wasserstein2 and entropic_partial_wasserstein work with backend (#449) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add test of partial_wasserstein with torch tensors * WIP: differentiable ot.partial.partial_wasserstein * change test of torch partial * make partial_wasserstein2 work with torch * test backward through ot.partial.partial_wasserstein2 * add test of entropic_partial_wasserstein with torch tensors * make entropic_partial_wasserstein work with torch tensors * add test of backward through entropic_partial_wasserstein * rm unused import * test partial_wasserstein with all 
backends * tests of partial fcts: check if torch is available * partial: check if marginals are empty arrays * add tests when marginals are empty arrays and/or m=None * add PR to RELEASES.md --------- Co-authored-by: Antoine Collas <22830806+antoinecollas@users.noreply.github.com> Co-authored-by: Rémi Flamary --- RELEASES.md | 2 +- ot/partial.py | 95 +++++++++++++++++++++++++------------------ test/test_partial.py | 113 +++++++++++++++++++++++++++++++++++++++++---------- 3 files changed, 148 insertions(+), 62 deletions(-) (limited to 'ot/partial.py') diff --git a/RELEASES.md b/RELEASES.md index 5d966c2..e4c6e15 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -14,7 +14,7 @@ - Added Generalized Wasserstein Barycenter solver + example (PR #372), fixed graphical details on the example (PR #376) - Added Free Support Sinkhorn Barycenter + example (PR #387) - New API for OT solver using function `ot.solve` (PR #388) -- Backend version of `ot.partial` and `ot.smooth` (PR #388) +- Backend version of `ot.partial` and `ot.smooth` (PR #388 and #449) - Added argument for warmstart of dual potentials in Sinkhorn-based methods in `ot.bregman` (PR #437) - Add parameters method in `ot.da.SinkhornTransport` (PR #440) - `ot.dr` now uses the new Pymanopt API and POT is compatible with current Pymanopt (PR #443) diff --git a/ot/partial.py b/ot/partial.py index eae91c4..bf4119d 100755 --- a/ot/partial.py +++ b/ot/partial.py @@ -120,7 +120,7 @@ def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False, nx = get_backend(a, b, M) - if nx.sum(a) > 1 or nx.sum(b) > 1: + if nx.sum(a) > 1 + 1e-15 or nx.sum(b) > 1 + 1e-15: # 1e-15 for numerical errors raise ValueError("Problem infeasible. 
Check that a and b are in the " "simplex") @@ -270,36 +270,43 @@ def partial_wasserstein(a, b, M, m=None, nb_dummies=1, log=False, **kwargs): nx = get_backend(a, b, M) + dim_a, dim_b = M.shape + if len(a) == 0: + a = nx.ones(dim_a, type_as=a) / dim_a + if len(b) == 0: + b = nx.ones(dim_b, type_as=b) / dim_b + if m is None: return partial_wasserstein_lagrange(a, b, M, log=log, **kwargs) elif m < 0: raise ValueError("Problem infeasible. Parameter m should be greater" " than 0.") - elif m > nx.min((nx.sum(a), nx.sum(b))): + elif m > nx.min(nx.stack((nx.sum(a), nx.sum(b)))): raise ValueError("Problem infeasible. Parameter m should lower or" " equal than min(|a|_1, |b|_1).") - a0, b0, M0 = a, b, M - # convert to humpy - a, b, M = nx.to_numpy(a, b, M) - - b_extended = np.append(b, [(np.sum(a) - m) / nb_dummies] * nb_dummies) - a_extended = np.append(a, [(np.sum(b) - m) / nb_dummies] * nb_dummies) - M_extended = np.zeros((len(a_extended), len(b_extended))) - M_extended[-nb_dummies:, -nb_dummies:] = np.max(M) * 2 - M_extended[:len(a), :len(b)] = M + b_extension = nx.ones(nb_dummies, type_as=b) * (nx.sum(a) - m) / nb_dummies + b_extended = nx.concatenate((b, b_extension)) + a_extension = nx.ones(nb_dummies, type_as=a) * (nx.sum(b) - m) / nb_dummies + a_extended = nx.concatenate((a, a_extension)) + M_extension = nx.ones((nb_dummies, nb_dummies), type_as=M) * nx.max(M) * 2 + M_extended = nx.concatenate( + (nx.concatenate((M, nx.zeros((M.shape[0], M_extension.shape[1]))), axis=1), + nx.concatenate((nx.zeros((M_extension.shape[0], M.shape[1])), M_extension), axis=1)), + axis=0 + ) gamma, log_emd = emd(a_extended, b_extended, M_extended, log=True, **kwargs) - gamma = nx.from_numpy(gamma[:len(a), :len(b)], type_as=M) + gamma = gamma[:len(a), :len(b)] if log_emd['warning'] is not None: raise ValueError("Error in the EMD resolution: try to increase the" " number of dummy points") - log_emd['partial_w_dist'] = nx.sum(M0 * gamma) - log_emd['u'] = nx.from_numpy(log_emd['u'][:len(a)], 
type_as=a0) - log_emd['v'] = nx.from_numpy(log_emd['v'][:len(b)], type_as=b0) + log_emd['partial_w_dist'] = nx.sum(M * gamma) + log_emd['u'] = log_emd['u'][:len(a)] + log_emd['v'] = log_emd['v'][:len(b)] if log: return gamma, log_emd @@ -389,14 +396,18 @@ def partial_wasserstein2(a, b, M, m=None, nb_dummies=1, log=False, **kwargs): NeurIPS. """ + a, b, M = list_to_array(a, b, M) + + nx = get_backend(a, b, M) + partial_gw, log_w = partial_wasserstein(a, b, M, m, nb_dummies, log=True, **kwargs) log_w['T'] = partial_gw if log: - return np.sum(partial_gw * M), log_w + return nx.sum(partial_gw * M), log_w else: - return np.sum(partial_gw * M) + return nx.sum(partial_gw * M) def gwgrad_partial(C1, C2, T): @@ -838,60 +849,64 @@ def entropic_partial_wasserstein(a, b, M, reg, m=None, numItermax=1000, ot.partial.partial_wasserstein: exact Partial Wasserstein """ - a = np.asarray(a, dtype=np.float64) - b = np.asarray(b, dtype=np.float64) - M = np.asarray(M, dtype=np.float64) + a, b, M = list_to_array(a, b, M) + + nx = get_backend(a, b, M) dim_a, dim_b = M.shape - dx = np.ones(dim_a, dtype=np.float64) - dy = np.ones(dim_b, dtype=np.float64) + dx = nx.ones(dim_a, type_as=a) + dy = nx.ones(dim_b, type_as=b) if len(a) == 0: - a = np.ones(dim_a, dtype=np.float64) / dim_a + a = nx.ones(dim_a, type_as=a) / dim_a if len(b) == 0: - b = np.ones(dim_b, dtype=np.float64) / dim_b + b = nx.ones(dim_b, type_as=b) / dim_b if m is None: - m = np.min((np.sum(a), np.sum(b))) * 1.0 + m = nx.min(nx.stack((nx.sum(a), nx.sum(b)))) * 1.0 if m < 0: raise ValueError("Problem infeasible. Parameter m should be greater" " than 0.") - if m > np.min((np.sum(a), np.sum(b))): + if m > nx.min(nx.stack((nx.sum(a), nx.sum(b)))): raise ValueError("Problem infeasible. 
Parameter m should lower or" " equal than min(|a|_1, |b|_1).") log_e = {'err': []} - # Next 3 lines equivalent to K=np.exp(-M/reg), but faster to compute - K = np.empty(M.shape, dtype=M.dtype) - np.divide(M, -reg, out=K) - np.exp(K, out=K) - np.multiply(K, m / np.sum(K), out=K) + if type(a) == type(b) == type(M) == np.ndarray: + # Next 3 lines equivalent to K=nx.exp(-M/reg), but faster to compute + K = np.empty(M.shape, dtype=M.dtype) + np.divide(M, -reg, out=K) + np.exp(K, out=K) + np.multiply(K, m / np.sum(K), out=K) + else: + K = nx.exp(-M / reg) + K = K * m / nx.sum(K) err, cpt = 1, 0 - q1 = np.ones(K.shape) - q2 = np.ones(K.shape) - q3 = np.ones(K.shape) + q1 = nx.ones(K.shape, type_as=K) + q2 = nx.ones(K.shape, type_as=K) + q3 = nx.ones(K.shape, type_as=K) while (err > stopThr and cpt < numItermax): Kprev = K K = K * q1 - K1 = np.dot(np.diag(np.minimum(a / np.sum(K, axis=1), dx)), K) + K1 = nx.dot(nx.diag(nx.minimum(a / nx.sum(K, axis=1), dx)), K) q1 = q1 * Kprev / K1 K1prev = K1 K1 = K1 * q2 - K2 = np.dot(K1, np.diag(np.minimum(b / np.sum(K1, axis=0), dy))) + K2 = nx.dot(K1, nx.diag(nx.minimum(b / nx.sum(K1, axis=0), dy))) q2 = q2 * K1prev / K2 K2prev = K2 K2 = K2 * q3 - K = K2 * (m / np.sum(K2)) + K = K2 * (m / nx.sum(K2)) q3 = q3 * K2prev / K - if np.any(np.isnan(K)) or np.any(np.isinf(K)): + if nx.any(nx.isnan(K)) or nx.any(nx.isinf(K)): print('Warning: numerical errors at iteration', cpt) break if cpt % 10 == 0: - err = np.linalg.norm(Kprev - K) + err = nx.norm(Kprev - K) if log: log_e['err'].append(err) if verbose: @@ -901,7 +916,7 @@ def entropic_partial_wasserstein(a, b, M, reg, m=None, numItermax=1000, print('{:5d}|{:8e}|'.format(cpt, err)) cpt = cpt + 1 - log_e['partial_w_dist'] = np.sum(M * K) + log_e['partial_w_dist'] = nx.sum(M * K) if log: return K, log_e else: diff --git a/test/test_partial.py b/test/test_partial.py index ae4a1ab..86f9e62 100755 --- a/test/test_partial.py +++ b/test/test_partial.py @@ -8,6 +8,7 @@ import numpy as np import 
scipy as sp import ot +from ot.backend import to_numpy, torch import pytest @@ -82,7 +83,7 @@ def test_partial_wasserstein_lagrange(): w0, log0 = ot.partial.partial_wasserstein_lagrange(p, q, M, 100, log=True) -def test_partial_wasserstein(): +def test_partial_wasserstein(nx): n_samples = 20 # nb samples (gaussian) n_noise = 20 # nb of samples (noise) @@ -102,25 +103,20 @@ def test_partial_wasserstein(): m = 0.5 + p, q, M = nx.from_numpy(p, q, M) + w0, log0 = ot.partial.partial_wasserstein(p, q, M, m=m, log=True) - w, log = ot.partial.entropic_partial_wasserstein(p, q, M, reg=1, m=m, - log=True, verbose=True) + w, log = ot.partial.entropic_partial_wasserstein(p, q, M, reg=1, m=m, log=True, verbose=True) # check constraints - np.testing.assert_equal( - w0.sum(1) - p <= 1e-5, [True] * len(p)) # cf convergence wasserstein - np.testing.assert_equal( - w0.sum(0) - q <= 1e-5, [True] * len(q)) # cf convergence wasserstein - np.testing.assert_equal( - w.sum(1) - p <= 1e-5, [True] * len(p)) # cf convergence wasserstein - np.testing.assert_equal( - w.sum(0) - q <= 1e-5, [True] * len(q)) # cf convergence wasserstein + np.testing.assert_equal(to_numpy(nx.sum(w0, axis=1) - p) <= 1e-5, [True] * len(p)) + np.testing.assert_equal(to_numpy(nx.sum(w0, axis=0) - q) <= 1e-5, [True] * len(q)) + np.testing.assert_equal(to_numpy(nx.sum(w0, axis=1) - p) <= 1e-5, [True] * len(p)) + np.testing.assert_equal(to_numpy(nx.sum(w0, axis=0) - q) <= 1e-5, [True] * len(q)) # check transported mass - np.testing.assert_allclose( - np.sum(w0), m, atol=1e-04) - np.testing.assert_allclose( - np.sum(w), m, atol=1e-04) + np.testing.assert_allclose(np.sum(to_numpy(w0)), m, atol=1e-04) + np.testing.assert_allclose(np.sum(to_numpy(w)), m, atol=1e-04) w0, log0 = ot.partial.partial_wasserstein2(p, q, M, m=m, log=True) w0_val = ot.partial.partial_wasserstein2(p, q, M, m=m, log=False) @@ -130,12 +126,87 @@ def test_partial_wasserstein(): np.testing.assert_allclose(w0, w0_val, atol=1e-1, rtol=1e-1) # check 
constraints - np.testing.assert_equal( - G.sum(1) - p <= 1e-5, [True] * len(p)) # cf convergence wasserstein - np.testing.assert_equal( - G.sum(0) - q <= 1e-5, [True] * len(q)) # cf convergence wasserstein - np.testing.assert_allclose( - np.sum(G), m, atol=1e-04) + np.testing.assert_equal(to_numpy(nx.sum(G, axis=1) - p) <= 1e-5, [True] * len(p)) + np.testing.assert_equal(to_numpy(nx.sum(G, axis=0) - q) <= 1e-5, [True] * len(q)) + np.testing.assert_allclose(np.sum(to_numpy(G)), m, atol=1e-04) + + empty_array = nx.zeros(0, type_as=M) + w = ot.partial.partial_wasserstein(empty_array, empty_array, M=M, m=None) + + # check constraints + np.testing.assert_equal(to_numpy(nx.sum(w, axis=1) - p) <= 1e-5, [True] * len(p)) + np.testing.assert_equal(to_numpy(nx.sum(w, axis=0) - q) <= 1e-5, [True] * len(q)) + np.testing.assert_equal(to_numpy(nx.sum(w, axis=1) - p) <= 1e-5, [True] * len(p)) + np.testing.assert_equal(to_numpy(nx.sum(w, axis=0) - q) <= 1e-5, [True] * len(q)) + + # check transported mass + np.testing.assert_allclose(np.sum(to_numpy(w)), 1, atol=1e-04) + + w0 = ot.partial.entropic_partial_wasserstein(empty_array, empty_array, M=M, reg=10, m=None) + + # check constraints + np.testing.assert_equal(to_numpy(nx.sum(w0, axis=1) - p) <= 1e-5, [True] * len(p)) + np.testing.assert_equal(to_numpy(nx.sum(w0, axis=0) - q) <= 1e-5, [True] * len(q)) + np.testing.assert_equal(to_numpy(nx.sum(w0, axis=1) - p) <= 1e-5, [True] * len(p)) + np.testing.assert_equal(to_numpy(nx.sum(w0, axis=0) - q) <= 1e-5, [True] * len(q)) + + # check transported mass + np.testing.assert_allclose(np.sum(to_numpy(w0)), 1, atol=1e-04) + + +def test_partial_wasserstein2_gradient(): + if torch: + n_samples = 40 + + mu = np.array([0, 0]) + cov = np.array([[1, 0], [0, 2]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov) + xt = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov) + + M = torch.tensor(ot.dist(xs, xt), requires_grad=True, dtype=torch.float64) + + p = 
torch.tensor(ot.unif(n_samples), dtype=torch.float64) + q = torch.tensor(ot.unif(n_samples), dtype=torch.float64) + + m = 0.5 + + w, log = ot.partial.partial_wasserstein2(p, q, M, m=m, log=True) + + w.backward() + + assert M.grad is not None + assert M.grad.shape == M.shape + + +def test_entropic_partial_wasserstein_gradient(): + if torch: + n_samples = 40 + + mu = np.array([0, 0]) + cov = np.array([[1, 0], [0, 2]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov) + xt = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov) + + M = torch.tensor(ot.dist(xs, xt), requires_grad=True, dtype=torch.float64) + + p = torch.tensor(ot.unif(n_samples), requires_grad=True, dtype=torch.float64) + q = torch.tensor(ot.unif(n_samples), requires_grad=True, dtype=torch.float64) + + m = 0.5 + reg = 1 + + _, log = ot.partial.entropic_partial_wasserstein(p, q, M, m=m, reg=reg, log=True) + + log['partial_w_dist'].backward() + + assert M.grad is not None + assert p.grad is not None + assert q.grad is not None + assert M.grad.shape == M.shape + assert p.grad.shape == p.shape + assert q.grad.shape == q.shape def test_partial_gromov_wasserstein(): -- cgit v1.2.3