summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorClément Bonet <32179275+clbonet@users.noreply.github.com>2023-05-05 10:53:48 +0200
committerGitHub <noreply@github.com>2023-05-05 10:53:48 +0200
commit7e0ea27ad9cad31cfc2181430d837c0a77a61568 (patch)
tree0a41128a975500bfef52a4c21b5af634adecc71a
parent83dc498b496087aea293df1445442d8728435211 (diff)
[MRG] Fix bug SSW backend (#471)
* fix bug np vs torch matmul * typo error * einsum projections ssw * Test broadcast matmul * einsum projections ssw * Test broadcast matmul * projections SSW with einsum * reduce number of samples in test wasserstein_circle_unif * Update releases.md --------- Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
-rw-r--r--RELEASES.md1
-rw-r--r--ot/backend.py23
-rw-r--r--ot/sliced.py33
-rw-r--r--test/test_1d_solver.py6
-rw-r--r--test/test_backend.py15
-rw-r--r--test/test_sliced.py32
6 files changed, 90 insertions, 20 deletions
diff --git a/RELEASES.md b/RELEASES.md
index 586089b..f393883 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -12,6 +12,7 @@
- Major documentation cleanup (PR #462, #467)
- Fix gradients for "Wasserstein2 Minibatch GAN" example (PR #466)
- Faster Bures-Wasserstein distance with NumPy backend (PR #468)
+- Fix issue backend for ot.sliced_wasserstein_sphere ot.sliced_wasserstein_sphere_unif (PR #471)
## 0.9.0
diff --git a/ot/backend.py b/ot/backend.py
index eecf9dd..d661c74 100644
--- a/ot/backend.py
+++ b/ot/backend.py
@@ -959,6 +959,14 @@ class Backend():
"""
raise NotImplementedError()
+ def matmul(self, a, b):
+ r"""
+ Matrix product of two arrays.
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.matmul.html#numpy.matmul
+ """
+ raise NotImplementedError()
+
class NumpyBackend(Backend):
"""
@@ -1293,6 +1301,9 @@ class NumpyBackend(Backend):
return args[0]
return args
+ def matmul(self, a, b):
+ return np.matmul(a, b)
+
class JaxBackend(Backend):
"""
@@ -1645,6 +1656,9 @@ class JaxBackend(Backend):
return jax.lax.stop_gradient((args[0],))[0]
return [jax.lax.stop_gradient((a,))[0] for a in args]
+ def matmul(self, a, b):
+ return jnp.matmul(a, b)
+
class TorchBackend(Backend):
"""
@@ -2098,6 +2112,9 @@ class TorchBackend(Backend):
return args[0].detach()
return [a.detach() for a in args]
+ def matmul(self, a, b):
+ return torch.matmul(a, b)
+
class CupyBackend(Backend): # pragma: no cover
"""
@@ -2474,6 +2491,9 @@ class CupyBackend(Backend): # pragma: no cover
return args[0]
return args
+ def matmul(self, a, b):
+ return cp.matmul(a, b)
+
class TensorflowBackend(Backend):
@@ -2865,3 +2885,6 @@ class TensorflowBackend(Backend):
if len(args) == 1:
return tf.stop_gradient(args[0])
return [tf.stop_gradient(a) for a in args]
+
+ def matmul(self, a, b):
+ return tnp.matmul(a, b)
diff --git a/ot/sliced.py b/ot/sliced.py
index 3a1644d..fd86df9 100644
--- a/ot/sliced.py
+++ b/ot/sliced.py
@@ -260,7 +260,7 @@ def max_sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50,
def sliced_wasserstein_sphere(X_s, X_t, a=None, b=None, n_projections=50,
- p=2, seed=None, log=False):
+ p=2, projections=None, seed=None, log=False):
r"""
Compute the spherical sliced-Wasserstein discrepancy.
@@ -287,6 +287,8 @@ def sliced_wasserstein_sphere(X_s, X_t, a=None, b=None, n_projections=50,
Number of projections used for the Monte-Carlo approximation
p: float, optional (default=2)
Power p used for computing the spherical sliced Wasserstein
+ projections: shape (n_projections, dim, 2), optional
+ Projection matrix (n_projections and seed are not used in this case)
seed: int or RandomState or None, optional
Seed used for random number generator
log: bool, optional
@@ -326,22 +328,25 @@ def sliced_wasserstein_sphere(X_s, X_t, a=None, b=None, n_projections=50,
if nx.any(nx.abs(nx.sum(X_s**2, axis=-1) - 1) > 10**(-4)):
raise ValueError("X_s is not on the sphere.")
if nx.any(nx.abs(nx.sum(X_t**2, axis=-1) - 1) > 10**(-4)):
- raise ValueError("Xt is not on the sphere.")
+ raise ValueError("X_t is not on the sphere.")
- # Uniforms and independent samples on the Stiefel manifold V_{d,2}
- if isinstance(seed, np.random.RandomState) and str(nx) == 'numpy':
- Z = seed.randn(n_projections, d, 2)
+ if projections is None:
+ # Uniforms and independent samples on the Stiefel manifold V_{d,2}
+ if isinstance(seed, np.random.RandomState) and str(nx) == 'numpy':
+ Z = seed.randn(n_projections, d, 2)
+ else:
+ if seed is not None:
+ nx.seed(seed)
+ Z = nx.randn(n_projections, d, 2, type_as=X_s)
+
+ projections, _ = nx.qr(Z)
else:
- if seed is not None:
- nx.seed(seed)
- Z = nx.randn(n_projections, d, 2, type_as=X_s)
-
- projections, _ = nx.qr(Z)
+ n_projections = projections.shape[0]
# Projection on S^1
# Projection on plane
- Xps = nx.transpose(nx.reshape(nx.dot(nx.transpose(projections, (0, 2, 1))[:, None], X_s[:, :, None]), (n_projections, 2, n)), (0, 2, 1))
- Xpt = nx.transpose(nx.reshape(nx.dot(nx.transpose(projections, (0, 2, 1))[:, None], X_t[:, :, None]), (n_projections, 2, m)), (0, 2, 1))
+ Xps = nx.einsum("ikj, lk -> ilj", projections, X_s)
+ Xpt = nx.einsum("ikj, lk -> ilj", projections, X_t)
# Projection on sphere
Xps = Xps / nx.sqrt(nx.sum(Xps**2, -1, keepdims=True))
@@ -425,9 +430,11 @@ def sliced_wasserstein_sphere_unif(X_s, a=None, n_projections=50, seed=None, log
# Projection on S^1
# Projection on plane
- Xps = nx.transpose(nx.reshape(nx.dot(nx.transpose(projections, (0, 2, 1))[:, None], X_s[:, :, None]), (n_projections, 2, n)), (0, 2, 1))
+ Xps = nx.einsum("ikj, lk -> ilj", projections, X_s)
+
# Projection on sphere
Xps = Xps / nx.sqrt(nx.sum(Xps**2, -1, keepdims=True))
+
# Get coordinates on [0,1[
Xps_coords = nx.reshape(get_coordinate_circle(nx.reshape(Xps, (-1, 2))), (n_projections, n))
diff --git a/test/test_1d_solver.py b/test/test_1d_solver.py
index 21abd1d..075a415 100644
--- a/test/test_1d_solver.py
+++ b/test/test_1d_solver.py
@@ -279,7 +279,7 @@ def test_wasserstein1d_circle_devices(nx):
def test_wasserstein_1d_unif_circle():
# test semidiscrete_wasserstein2_unif_circle versus wasserstein_circle
n = 20
- m = 50000
+ m = 1000
rng = np.random.RandomState(0)
u = rng.rand(n,)
@@ -298,8 +298,8 @@ def test_wasserstein_1d_unif_circle():
wass2_unif_circle = ot.semidiscrete_wasserstein2_unif_circle(u, w_u)
# check loss is similar
- np.testing.assert_allclose(wass2, wass2_unif_circle, atol=1e-3)
- np.testing.assert_allclose(wass2_circle, wass2_unif_circle, atol=1e-3)
+ np.testing.assert_allclose(wass2, wass2_unif_circle, atol=1e-2)
+ np.testing.assert_allclose(wass2_circle, wass2_unif_circle, atol=1e-2)
def test_wasserstein1d_unif_circle_devices(nx):
diff --git a/test/test_backend.py b/test/test_backend.py
index 5351e52..fedc62f 100644
--- a/test/test_backend.py
+++ b/test/test_backend.py
@@ -298,6 +298,8 @@ def test_empty_backend():
nx.transpose(M)
with pytest.raises(NotImplementedError):
nx.detach(M)
+ with pytest.raises(NotImplementedError):
+ nx.matmul(M, M.T)
def test_func_backends(nx):
@@ -308,6 +310,9 @@ def test_func_backends(nx):
v = rnd.randn(3)
val = np.array([1.0])
+ M1 = rnd.randn(1, 2, 10, 10)
+ M2 = rnd.randn(3, 1, 10, 10)
+
# Sparse tensors test
sp_row = np.array([0, 3, 1, 0, 3])
sp_col = np.array([0, 3, 1, 2, 2])
@@ -326,6 +331,9 @@ def test_func_backends(nx):
SquareMb = nx.from_numpy(SquareM)
vb = nx.from_numpy(v)
+ M1b = nx.from_numpy(M1)
+ M2b = nx.from_numpy(M2)
+
val = nx.from_numpy(val)
sp_rowb = nx.from_numpy(sp_row)
@@ -661,6 +669,13 @@ def test_func_backends(nx):
lst_b.append(nx.to_numpy(B))
lst_name.append("detach B")
+ A = nx.matmul(Mb, Mb.T)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append("matmul")
+ A = nx.matmul(M1b, M2b)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append("matmul broadcast")
+
assert not nx.array_equal(Mb, vb), "array_equal (shape)"
assert nx.array_equal(Mb, Mb), "array_equal (elements) - expected true"
assert not nx.array_equal(
diff --git a/test/test_sliced.py b/test/test_sliced.py
index 7b7437a..6d5a27b 100644
--- a/test/test_sliced.py
+++ b/test/test_sliced.py
@@ -295,6 +295,26 @@ def test_sliced_sphere_same_dist():
np.testing.assert_almost_equal(res, 0.)
+def test_sliced_sphere_same_proj():
+ n_projections = 10
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 3)
+ x = x / np.sqrt(np.sum(x**2, -1, keepdims=True))
+
+ y = rng.randn(n, 3)
+ y = y / np.sqrt(np.sum(y**2, -1, keepdims=True))
+
+ seed = 42
+
+ cost1, log1 = ot.sliced_wasserstein_sphere(x, y, seed=seed, n_projections=n_projections, log=True)
+ cost2, log2 = ot.sliced_wasserstein_sphere(x, y, seed=seed, n_projections=n_projections, log=True)
+
+ assert np.allclose(log1['projections'], log2['projections'])
+ assert np.isclose(cost1, cost2)
+
+
def test_sliced_sphere_bad_shapes():
n = 100
rng = np.random.RandomState(0)
@@ -398,28 +418,32 @@ def test_sliced_sphere_backend_type_devices(nx):
y = rng.randn(2 * n, 3)
y = y / np.sqrt(np.sum(y**2, -1, keepdims=True))
+ sw_np, log = ot.sliced_wasserstein_sphere(x, y, log=True)
+ P = log["projections"]
+
for tp in nx.__type_list__:
print(nx.dtype_device(tp))
xb, yb = nx.from_numpy(x, y, type_as=tp)
- valb = ot.sliced_wasserstein_sphere(xb, yb)
+ valb = ot.sliced_wasserstein_sphere(xb, yb, projections=nx.from_numpy(P, type_as=tp))
nx.assert_same_dtype_device(xb, valb)
+ np.testing.assert_almost_equal(sw_np, nx.to_numpy(valb))
def test_sliced_sphere_gradient():
if torch:
import torch.nn.functional as F
- X0 = torch.randn((500, 3))
+ X0 = torch.randn((20, 3))
X0 = F.normalize(X0, p=2, dim=-1)
X0.requires_grad_(True)
- X1 = torch.randn((500, 3))
+ X1 = torch.randn((20, 3))
X1 = F.normalize(X1, p=2, dim=-1)
- sw = ot.sliced_wasserstein_sphere(X1, X0, n_projections=500, p=2)
+ sw = ot.sliced_wasserstein_sphere(X1, X0, n_projections=100, p=2)
grad_x0 = torch.autograd.grad(sw, X0)[0]
assert not torch.any(torch.isnan(grad_x0))