summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAntoine Collas <contact@antoinecollas.fr>2023-03-21 15:18:09 +0100
committerGitHub <noreply@github.com>2023-03-21 15:18:09 +0100
commitb9ed7b1650475420cc5bbec6c31476cc098790d5 (patch)
treef6624d9509288466ac3e9c0cd475b4a984d72ce4
parentc48cd76235569ada98af6b1bba589510a2818906 (diff)
[MRG] Make partial_wasserstein, partial_wasserstein2 and entropic_partial_wasserstein work with backend (#449)
* add test of partial_wasserstein with torch tensors * WIP: differentiable ot.partial.partial_wasserstein * change test of torch partial * make partial_wasserstein2 work with torch * test backward through ot.partial.partial_wasserstein2 * add test of entropic_partial_wasserstein with torch tensors * make entropic_partial_wasserstein work with torch tensors * add test of backward through entropic_partial_wasserstein * rm unused import * test partial_wasserstein with all backends * tests of partial fcts: check if torch is available * partial: check if marginals are empty arrays * add tests when marginals are empty arrays and/or m=None * add PR to RELEASES.md --------- Co-authored-by: Antoine Collas <22830806+antoinecollas@users.noreply.github.com> Co-authored-by: RĂ©mi Flamary <remi.flamary@gmail.com>
-rw-r--r--RELEASES.md2
-rwxr-xr-xot/partial.py95
-rwxr-xr-xtest/test_partial.py113
3 files changed, 148 insertions, 62 deletions
diff --git a/RELEASES.md b/RELEASES.md
index 5d966c2..e4c6e15 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -14,7 +14,7 @@
- Added Generalized Wasserstein Barycenter solver + example (PR #372), fixed graphical details on the example (PR #376)
- Added Free Support Sinkhorn Barycenter + example (PR #387)
- New API for OT solver using function `ot.solve` (PR #388)
-- Backend version of `ot.partial` and `ot.smooth` (PR #388)
+- Backend version of `ot.partial` and `ot.smooth` (PR #388 and #449)
- Added argument for warmstart of dual potentials in Sinkhorn-based methods in `ot.bregman` (PR #437)
- Add parameters method in `ot.da.SinkhornTransport` (PR #440)
- `ot.dr` now uses the new Pymanopt API and POT is compatible with current Pymanopt (PR #443)
diff --git a/ot/partial.py b/ot/partial.py
index eae91c4..bf4119d 100755
--- a/ot/partial.py
+++ b/ot/partial.py
@@ -120,7 +120,7 @@ def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False,
nx = get_backend(a, b, M)
- if nx.sum(a) > 1 or nx.sum(b) > 1:
+ if nx.sum(a) > 1 + 1e-15 or nx.sum(b) > 1 + 1e-15: # 1e-15 for numerical errors
raise ValueError("Problem infeasible. Check that a and b are in the "
"simplex")
@@ -270,36 +270,43 @@ def partial_wasserstein(a, b, M, m=None, nb_dummies=1, log=False, **kwargs):
nx = get_backend(a, b, M)
+ dim_a, dim_b = M.shape
+ if len(a) == 0:
+ a = nx.ones(dim_a, type_as=a) / dim_a
+ if len(b) == 0:
+ b = nx.ones(dim_b, type_as=b) / dim_b
+
if m is None:
return partial_wasserstein_lagrange(a, b, M, log=log, **kwargs)
elif m < 0:
raise ValueError("Problem infeasible. Parameter m should be greater"
" than 0.")
- elif m > nx.min((nx.sum(a), nx.sum(b))):
+ elif m > nx.min(nx.stack((nx.sum(a), nx.sum(b)))):
raise ValueError("Problem infeasible. Parameter m should lower or"
" equal than min(|a|_1, |b|_1).")
- a0, b0, M0 = a, b, M
- # convert to humpy
- a, b, M = nx.to_numpy(a, b, M)
-
- b_extended = np.append(b, [(np.sum(a) - m) / nb_dummies] * nb_dummies)
- a_extended = np.append(a, [(np.sum(b) - m) / nb_dummies] * nb_dummies)
- M_extended = np.zeros((len(a_extended), len(b_extended)))
- M_extended[-nb_dummies:, -nb_dummies:] = np.max(M) * 2
- M_extended[:len(a), :len(b)] = M
+ b_extension = nx.ones(nb_dummies, type_as=b) * (nx.sum(a) - m) / nb_dummies
+ b_extended = nx.concatenate((b, b_extension))
+ a_extension = nx.ones(nb_dummies, type_as=a) * (nx.sum(b) - m) / nb_dummies
+ a_extended = nx.concatenate((a, a_extension))
+ M_extension = nx.ones((nb_dummies, nb_dummies), type_as=M) * nx.max(M) * 2
+ M_extended = nx.concatenate(
+ (nx.concatenate((M, nx.zeros((M.shape[0], M_extension.shape[1]))), axis=1),
+ nx.concatenate((nx.zeros((M_extension.shape[0], M.shape[1])), M_extension), axis=1)),
+ axis=0
+ )
gamma, log_emd = emd(a_extended, b_extended, M_extended, log=True,
**kwargs)
- gamma = nx.from_numpy(gamma[:len(a), :len(b)], type_as=M)
+ gamma = gamma[:len(a), :len(b)]
if log_emd['warning'] is not None:
raise ValueError("Error in the EMD resolution: try to increase the"
" number of dummy points")
- log_emd['partial_w_dist'] = nx.sum(M0 * gamma)
- log_emd['u'] = nx.from_numpy(log_emd['u'][:len(a)], type_as=a0)
- log_emd['v'] = nx.from_numpy(log_emd['v'][:len(b)], type_as=b0)
+ log_emd['partial_w_dist'] = nx.sum(M * gamma)
+ log_emd['u'] = log_emd['u'][:len(a)]
+ log_emd['v'] = log_emd['v'][:len(b)]
if log:
return gamma, log_emd
@@ -389,14 +396,18 @@ def partial_wasserstein2(a, b, M, m=None, nb_dummies=1, log=False, **kwargs):
NeurIPS.
"""
+ a, b, M = list_to_array(a, b, M)
+
+ nx = get_backend(a, b, M)
+
partial_gw, log_w = partial_wasserstein(a, b, M, m, nb_dummies, log=True,
**kwargs)
log_w['T'] = partial_gw
if log:
- return np.sum(partial_gw * M), log_w
+ return nx.sum(partial_gw * M), log_w
else:
- return np.sum(partial_gw * M)
+ return nx.sum(partial_gw * M)
def gwgrad_partial(C1, C2, T):
@@ -838,60 +849,64 @@ def entropic_partial_wasserstein(a, b, M, reg, m=None, numItermax=1000,
ot.partial.partial_wasserstein: exact Partial Wasserstein
"""
- a = np.asarray(a, dtype=np.float64)
- b = np.asarray(b, dtype=np.float64)
- M = np.asarray(M, dtype=np.float64)
+ a, b, M = list_to_array(a, b, M)
+
+ nx = get_backend(a, b, M)
dim_a, dim_b = M.shape
- dx = np.ones(dim_a, dtype=np.float64)
- dy = np.ones(dim_b, dtype=np.float64)
+ dx = nx.ones(dim_a, type_as=a)
+ dy = nx.ones(dim_b, type_as=b)
if len(a) == 0:
- a = np.ones(dim_a, dtype=np.float64) / dim_a
+ a = nx.ones(dim_a, type_as=a) / dim_a
if len(b) == 0:
- b = np.ones(dim_b, dtype=np.float64) / dim_b
+ b = nx.ones(dim_b, type_as=b) / dim_b
if m is None:
- m = np.min((np.sum(a), np.sum(b))) * 1.0
+ m = nx.min(nx.stack((nx.sum(a), nx.sum(b)))) * 1.0
if m < 0:
raise ValueError("Problem infeasible. Parameter m should be greater"
" than 0.")
- if m > np.min((np.sum(a), np.sum(b))):
+ if m > nx.min(nx.stack((nx.sum(a), nx.sum(b)))):
raise ValueError("Problem infeasible. Parameter m should lower or"
" equal than min(|a|_1, |b|_1).")
log_e = {'err': []}
- # Next 3 lines equivalent to K=np.exp(-M/reg), but faster to compute
- K = np.empty(M.shape, dtype=M.dtype)
- np.divide(M, -reg, out=K)
- np.exp(K, out=K)
- np.multiply(K, m / np.sum(K), out=K)
+ if type(a) == type(b) == type(M) == np.ndarray:
+ # Next 3 lines equivalent to K=nx.exp(-M/reg), but faster to compute
+ K = np.empty(M.shape, dtype=M.dtype)
+ np.divide(M, -reg, out=K)
+ np.exp(K, out=K)
+ np.multiply(K, m / np.sum(K), out=K)
+ else:
+ K = nx.exp(-M / reg)
+ K = K * m / nx.sum(K)
err, cpt = 1, 0
- q1 = np.ones(K.shape)
- q2 = np.ones(K.shape)
- q3 = np.ones(K.shape)
+ q1 = nx.ones(K.shape, type_as=K)
+ q2 = nx.ones(K.shape, type_as=K)
+ q3 = nx.ones(K.shape, type_as=K)
while (err > stopThr and cpt < numItermax):
Kprev = K
K = K * q1
- K1 = np.dot(np.diag(np.minimum(a / np.sum(K, axis=1), dx)), K)
+ K1 = nx.dot(nx.diag(nx.minimum(a / nx.sum(K, axis=1), dx)), K)
q1 = q1 * Kprev / K1
K1prev = K1
K1 = K1 * q2
- K2 = np.dot(K1, np.diag(np.minimum(b / np.sum(K1, axis=0), dy)))
+ K2 = nx.dot(K1, nx.diag(nx.minimum(b / nx.sum(K1, axis=0), dy)))
q2 = q2 * K1prev / K2
K2prev = K2
K2 = K2 * q3
- K = K2 * (m / np.sum(K2))
+ K = K2 * (m / nx.sum(K2))
q3 = q3 * K2prev / K
- if np.any(np.isnan(K)) or np.any(np.isinf(K)):
+ if nx.any(nx.isnan(K)) or nx.any(nx.isinf(K)):
print('Warning: numerical errors at iteration', cpt)
break
if cpt % 10 == 0:
- err = np.linalg.norm(Kprev - K)
+ err = nx.norm(Kprev - K)
if log:
log_e['err'].append(err)
if verbose:
@@ -901,7 +916,7 @@ def entropic_partial_wasserstein(a, b, M, reg, m=None, numItermax=1000,
print('{:5d}|{:8e}|'.format(cpt, err))
cpt = cpt + 1
- log_e['partial_w_dist'] = np.sum(M * K)
+ log_e['partial_w_dist'] = nx.sum(M * K)
if log:
return K, log_e
else:
diff --git a/test/test_partial.py b/test/test_partial.py
index ae4a1ab..86f9e62 100755
--- a/test/test_partial.py
+++ b/test/test_partial.py
@@ -8,6 +8,7 @@
import numpy as np
import scipy as sp
import ot
+from ot.backend import to_numpy, torch
import pytest
@@ -82,7 +83,7 @@ def test_partial_wasserstein_lagrange():
w0, log0 = ot.partial.partial_wasserstein_lagrange(p, q, M, 100, log=True)
-def test_partial_wasserstein():
+def test_partial_wasserstein(nx):
n_samples = 20 # nb samples (gaussian)
n_noise = 20 # nb of samples (noise)
@@ -102,25 +103,20 @@ def test_partial_wasserstein():
m = 0.5
+ p, q, M = nx.from_numpy(p, q, M)
+
w0, log0 = ot.partial.partial_wasserstein(p, q, M, m=m, log=True)
- w, log = ot.partial.entropic_partial_wasserstein(p, q, M, reg=1, m=m,
- log=True, verbose=True)
+ w, log = ot.partial.entropic_partial_wasserstein(p, q, M, reg=1, m=m, log=True, verbose=True)
# check constraints
- np.testing.assert_equal(
- w0.sum(1) - p <= 1e-5, [True] * len(p)) # cf convergence wasserstein
- np.testing.assert_equal(
- w0.sum(0) - q <= 1e-5, [True] * len(q)) # cf convergence wasserstein
- np.testing.assert_equal(
- w.sum(1) - p <= 1e-5, [True] * len(p)) # cf convergence wasserstein
- np.testing.assert_equal(
- w.sum(0) - q <= 1e-5, [True] * len(q)) # cf convergence wasserstein
+ np.testing.assert_equal(to_numpy(nx.sum(w0, axis=1) - p) <= 1e-5, [True] * len(p))
+ np.testing.assert_equal(to_numpy(nx.sum(w0, axis=0) - q) <= 1e-5, [True] * len(q))
+ np.testing.assert_equal(to_numpy(nx.sum(w0, axis=1) - p) <= 1e-5, [True] * len(p))
+ np.testing.assert_equal(to_numpy(nx.sum(w0, axis=0) - q) <= 1e-5, [True] * len(q))
# check transported mass
- np.testing.assert_allclose(
- np.sum(w0), m, atol=1e-04)
- np.testing.assert_allclose(
- np.sum(w), m, atol=1e-04)
+ np.testing.assert_allclose(np.sum(to_numpy(w0)), m, atol=1e-04)
+ np.testing.assert_allclose(np.sum(to_numpy(w)), m, atol=1e-04)
w0, log0 = ot.partial.partial_wasserstein2(p, q, M, m=m, log=True)
w0_val = ot.partial.partial_wasserstein2(p, q, M, m=m, log=False)
@@ -130,12 +126,87 @@ def test_partial_wasserstein():
np.testing.assert_allclose(w0, w0_val, atol=1e-1, rtol=1e-1)
# check constraints
- np.testing.assert_equal(
- G.sum(1) - p <= 1e-5, [True] * len(p)) # cf convergence wasserstein
- np.testing.assert_equal(
- G.sum(0) - q <= 1e-5, [True] * len(q)) # cf convergence wasserstein
- np.testing.assert_allclose(
- np.sum(G), m, atol=1e-04)
+ np.testing.assert_equal(to_numpy(nx.sum(G, axis=1) - p) <= 1e-5, [True] * len(p))
+ np.testing.assert_equal(to_numpy(nx.sum(G, axis=0) - q) <= 1e-5, [True] * len(q))
+ np.testing.assert_allclose(np.sum(to_numpy(G)), m, atol=1e-04)
+
+ empty_array = nx.zeros(0, type_as=M)
+ w = ot.partial.partial_wasserstein(empty_array, empty_array, M=M, m=None)
+
+ # check constraints
+ np.testing.assert_equal(to_numpy(nx.sum(w, axis=1) - p) <= 1e-5, [True] * len(p))
+ np.testing.assert_equal(to_numpy(nx.sum(w, axis=0) - q) <= 1e-5, [True] * len(q))
+ np.testing.assert_equal(to_numpy(nx.sum(w, axis=1) - p) <= 1e-5, [True] * len(p))
+ np.testing.assert_equal(to_numpy(nx.sum(w, axis=0) - q) <= 1e-5, [True] * len(q))
+
+ # check transported mass
+ np.testing.assert_allclose(np.sum(to_numpy(w)), 1, atol=1e-04)
+
+ w0 = ot.partial.entropic_partial_wasserstein(empty_array, empty_array, M=M, reg=10, m=None)
+
+ # check constraints
+ np.testing.assert_equal(to_numpy(nx.sum(w0, axis=1) - p) <= 1e-5, [True] * len(p))
+ np.testing.assert_equal(to_numpy(nx.sum(w0, axis=0) - q) <= 1e-5, [True] * len(q))
+ np.testing.assert_equal(to_numpy(nx.sum(w0, axis=1) - p) <= 1e-5, [True] * len(p))
+ np.testing.assert_equal(to_numpy(nx.sum(w0, axis=0) - q) <= 1e-5, [True] * len(q))
+
+ # check transported mass
+ np.testing.assert_allclose(np.sum(to_numpy(w0)), 1, atol=1e-04)
+
+
+def test_partial_wasserstein2_gradient():
+ if torch:
+ n_samples = 40
+
+ mu = np.array([0, 0])
+ cov = np.array([[1, 0], [0, 2]])
+
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov)
+ xt = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov)
+
+ M = torch.tensor(ot.dist(xs, xt), requires_grad=True, dtype=torch.float64)
+
+ p = torch.tensor(ot.unif(n_samples), dtype=torch.float64)
+ q = torch.tensor(ot.unif(n_samples), dtype=torch.float64)
+
+ m = 0.5
+
+ w, log = ot.partial.partial_wasserstein2(p, q, M, m=m, log=True)
+
+ w.backward()
+
+ assert M.grad is not None
+ assert M.grad.shape == M.shape
+
+
+def test_entropic_partial_wasserstein_gradient():
+ if torch:
+ n_samples = 40
+
+ mu = np.array([0, 0])
+ cov = np.array([[1, 0], [0, 2]])
+
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov)
+ xt = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov)
+
+ M = torch.tensor(ot.dist(xs, xt), requires_grad=True, dtype=torch.float64)
+
+ p = torch.tensor(ot.unif(n_samples), requires_grad=True, dtype=torch.float64)
+ q = torch.tensor(ot.unif(n_samples), requires_grad=True, dtype=torch.float64)
+
+ m = 0.5
+ reg = 1
+
+ _, log = ot.partial.entropic_partial_wasserstein(p, q, M, m=m, reg=reg, log=True)
+
+ log['partial_w_dist'].backward()
+
+ assert M.grad is not None
+ assert p.grad is not None
+ assert q.grad is not None
+ assert M.grad.shape == M.shape
+ assert p.grad.shape == p.shape
+ assert q.grad.shape == q.shape
def test_partial_gromov_wasserstein():