diff options
author | Clément Bonet <32179275+clbonet@users.noreply.github.com> | 2023-05-05 10:53:48 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-05-05 10:53:48 +0200 |
commit | 7e0ea27ad9cad31cfc2181430d837c0a77a61568 (patch) | |
tree | 0a41128a975500bfef52a4c21b5af634adecc71a | |
parent | 83dc498b496087aea293df1445442d8728435211 (diff) |
[MRG] Fix bug SSW backend (#471)
* fix bug np vs torch matmul
* typo error
* einsum projections ssw
* Test broadcast matmul
* einsum projections ssw
* Test broadcast matmul
* projections SSW with einsum
* reduce number of samples in test wasserstein_circle_unif
* Update releases.md
---------
Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
-rw-r--r-- | RELEASES.md | 1 | ||||
-rw-r--r-- | ot/backend.py | 23 | ||||
-rw-r--r-- | ot/sliced.py | 33 | ||||
-rw-r--r-- | test/test_1d_solver.py | 6 | ||||
-rw-r--r-- | test/test_backend.py | 15 | ||||
-rw-r--r-- | test/test_sliced.py | 32 |
6 files changed, 90 insertions, 20 deletions
diff --git a/RELEASES.md b/RELEASES.md index 586089b..f393883 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -12,6 +12,7 @@ - Major documentation cleanup (PR #462, #467) - Fix gradients for "Wasserstein2 Minibatch GAN" example (PR #466) - Faster Bures-Wasserstein distance with NumPy backend (PR #468) +- Fix issue backend for ot.sliced_wasserstein_sphere ot.sliced_wasserstein_sphere_unif (PR #471) ## 0.9.0 diff --git a/ot/backend.py b/ot/backend.py index eecf9dd..d661c74 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -959,6 +959,14 @@ class Backend(): """ raise NotImplementedError() + def matmul(self, a, b): + r""" + Matrix product of two arrays. + + See: https://numpy.org/doc/stable/reference/generated/numpy.matmul.html#numpy.matmul + """ + raise NotImplementedError() + class NumpyBackend(Backend): """ @@ -1293,6 +1301,9 @@ class NumpyBackend(Backend): return args[0] return args + def matmul(self, a, b): + return np.matmul(a, b) + class JaxBackend(Backend): """ @@ -1645,6 +1656,9 @@ class JaxBackend(Backend): return jax.lax.stop_gradient((args[0],))[0] return [jax.lax.stop_gradient((a,))[0] for a in args] + def matmul(self, a, b): + return jnp.matmul(a, b) + class TorchBackend(Backend): """ @@ -2098,6 +2112,9 @@ class TorchBackend(Backend): return args[0].detach() return [a.detach() for a in args] + def matmul(self, a, b): + return torch.matmul(a, b) + class CupyBackend(Backend): # pragma: no cover """ @@ -2474,6 +2491,9 @@ class CupyBackend(Backend): # pragma: no cover return args[0] return args + def matmul(self, a, b): + return cp.matmul(a, b) + class TensorflowBackend(Backend): @@ -2865,3 +2885,6 @@ class TensorflowBackend(Backend): if len(args) == 1: return tf.stop_gradient(args[0]) return [tf.stop_gradient(a) for a in args] + + def matmul(self, a, b): + return tnp.matmul(a, b) diff --git a/ot/sliced.py b/ot/sliced.py index 3a1644d..fd86df9 100644 --- a/ot/sliced.py +++ b/ot/sliced.py @@ -260,7 +260,7 @@ def max_sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, def sliced_wasserstein_sphere(X_s, X_t, a=None, b=None, n_projections=50, - p=2, seed=None, log=False): + p=2, projections=None, seed=None, log=False): r""" Compute the spherical sliced-Wasserstein discrepancy. @@ -287,6 +287,8 @@ def sliced_wasserstein_sphere(X_s, X_t, a=None, b=None, n_projections=50, Number of projections used for the Monte-Carlo approximation p: float, optional (default=2) Power p used for computing the spherical sliced Wasserstein + projections: shape (n_projections, dim, 2), optional + Projection matrix (n_projections and seed are not used in this case) seed: int or RandomState or None, optional Seed used for random number generator log: bool, optional @@ -326,22 +328,25 @@ def sliced_wasserstein_sphere(X_s, X_t, a=None, b=None, n_projections=50, if nx.any(nx.abs(nx.sum(X_s**2, axis=-1) - 1) > 10**(-4)): raise ValueError("X_s is not on the sphere.") if nx.any(nx.abs(nx.sum(X_t**2, axis=-1) - 1) > 10**(-4)): - raise ValueError("Xt is not on the sphere.") + raise ValueError("X_t is not on the sphere.") - # Uniforms and independent samples on the Stiefel manifold V_{d,2} - if isinstance(seed, np.random.RandomState) and str(nx) == 'numpy': - Z = seed.randn(n_projections, d, 2) + if projections is None: + # Uniforms and independent samples on the Stiefel manifold V_{d,2} + if isinstance(seed, np.random.RandomState) and str(nx) == 'numpy': + Z = seed.randn(n_projections, d, 2) + else: + if seed is not None: + nx.seed(seed) + Z = nx.randn(n_projections, d, 2, type_as=X_s) + + projections, _ = nx.qr(Z) else: - if seed is not None: - nx.seed(seed) - Z = nx.randn(n_projections, d, 2, type_as=X_s) - - projections, _ = nx.qr(Z) + n_projections = projections.shape[0] # Projection on S^1 # Projection on plane - Xps = nx.transpose(nx.reshape(nx.dot(nx.transpose(projections, (0, 2, 1))[:, None], X_s[:, :, None]), (n_projections, 2, n)), (0, 2, 1)) - Xpt = nx.transpose(nx.reshape(nx.dot(nx.transpose(projections, (0, 2, 1))[:, None], X_t[:, :, None]), (n_projections, 2, m)), (0, 2, 1)) + Xps = nx.einsum("ikj, lk -> ilj", projections, X_s) + Xpt = nx.einsum("ikj, lk -> ilj", projections, X_t) # Projection on sphere Xps = Xps / nx.sqrt(nx.sum(Xps**2, -1, keepdims=True)) @@ -425,9 +430,11 @@ def sliced_wasserstein_sphere_unif(X_s, a=None, n_projections=50, seed=None, log # Projection on S^1 # Projection on plane - Xps = nx.transpose(nx.reshape(nx.dot(nx.transpose(projections, (0, 2, 1))[:, None], X_s[:, :, None]), (n_projections, 2, n)), (0, 2, 1)) + Xps = nx.einsum("ikj, lk -> ilj", projections, X_s) + # Projection on sphere Xps = Xps / nx.sqrt(nx.sum(Xps**2, -1, keepdims=True)) + # Get coordinates on [0,1[ Xps_coords = nx.reshape(get_coordinate_circle(nx.reshape(Xps, (-1, 2))), (n_projections, n)) diff --git a/test/test_1d_solver.py b/test/test_1d_solver.py index 21abd1d..075a415 100644 --- a/test/test_1d_solver.py +++ b/test/test_1d_solver.py @@ -279,7 +279,7 @@ def test_wasserstein1d_circle_devices(nx): def test_wasserstein_1d_unif_circle(): # test semidiscrete_wasserstein2_unif_circle versus wasserstein_circle n = 20 - m = 50000 + m = 1000 rng = np.random.RandomState(0) u = rng.rand(n,) @@ -298,8 +298,8 @@ def test_wasserstein_1d_unif_circle(): wass2_unif_circle = ot.semidiscrete_wasserstein2_unif_circle(u, w_u) # check loss is similar - np.testing.assert_allclose(wass2, wass2_unif_circle, atol=1e-3) - np.testing.assert_allclose(wass2_circle, wass2_unif_circle, atol=1e-3) + np.testing.assert_allclose(wass2, wass2_unif_circle, atol=1e-2) + np.testing.assert_allclose(wass2_circle, wass2_unif_circle, atol=1e-2) def test_wasserstein1d_unif_circle_devices(nx): diff --git a/test/test_backend.py b/test/test_backend.py index 5351e52..fedc62f 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -298,6 +298,8 @@ def test_empty_backend(): nx.transpose(M) with pytest.raises(NotImplementedError): nx.detach(M) + with pytest.raises(NotImplementedError): + nx.matmul(M, M.T) def test_func_backends(nx): @@ -308,6 +310,9 @@ def test_func_backends(nx): v = rnd.randn(3) val = np.array([1.0]) + M1 = rnd.randn(1, 2, 10, 10) + M2 = rnd.randn(3, 1, 10, 10) + # Sparse tensors test sp_row = np.array([0, 3, 1, 0, 3]) sp_col = np.array([0, 3, 1, 2, 2]) @@ -326,6 +331,9 @@ def test_func_backends(nx): SquareMb = nx.from_numpy(SquareM) vb = nx.from_numpy(v) + M1b = nx.from_numpy(M1) + M2b = nx.from_numpy(M2) + val = nx.from_numpy(val) sp_rowb = nx.from_numpy(sp_row) @@ -661,6 +669,13 @@ def test_func_backends(nx): lst_b.append(nx.to_numpy(B)) lst_name.append("detach B") + A = nx.matmul(Mb, Mb.T) + lst_b.append(nx.to_numpy(A)) + lst_name.append("matmul") + A = nx.matmul(M1b, M2b) + lst_b.append(nx.to_numpy(A)) + lst_name.append("matmul broadcast") + assert not nx.array_equal(Mb, vb), "array_equal (shape)" assert nx.array_equal(Mb, Mb), "array_equal (elements) - expected true" assert not nx.array_equal( diff --git a/test/test_sliced.py b/test/test_sliced.py index 7b7437a..6d5a27b 100644 --- a/test/test_sliced.py +++ b/test/test_sliced.py @@ -295,6 +295,26 @@ def test_sliced_sphere_same_dist(): np.testing.assert_almost_equal(res, 0.) +def test_sliced_sphere_same_proj(): + n_projections = 10 + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 3) + x = x / np.sqrt(np.sum(x**2, -1, keepdims=True)) + + y = rng.randn(n, 3) + y = y / np.sqrt(np.sum(y**2, -1, keepdims=True)) + + seed = 42 + + cost1, log1 = ot.sliced_wasserstein_sphere(x, y, seed=seed, n_projections=n_projections, log=True) + cost2, log2 = ot.sliced_wasserstein_sphere(x, y, seed=seed, n_projections=n_projections, log=True) + + assert np.allclose(log1['projections'], log2['projections']) + assert np.isclose(cost1, cost2) + + def test_sliced_sphere_bad_shapes(): n = 100 rng = np.random.RandomState(0) @@ -398,28 +418,32 @@ def test_sliced_sphere_backend_type_devices(nx): y = rng.randn(2 * n, 3) y = y / np.sqrt(np.sum(y**2, -1, keepdims=True)) + sw_np, log = ot.sliced_wasserstein_sphere(x, y, log=True) + P = log["projections"] + for tp in nx.__type_list__: print(nx.dtype_device(tp)) xb, yb = nx.from_numpy(x, y, type_as=tp) - valb = ot.sliced_wasserstein_sphere(xb, yb) + valb = ot.sliced_wasserstein_sphere(xb, yb, projections=nx.from_numpy(P, type_as=tp)) nx.assert_same_dtype_device(xb, valb) + np.testing.assert_almost_equal(sw_np, nx.to_numpy(valb)) def test_sliced_sphere_gradient(): if torch: import torch.nn.functional as F - X0 = torch.randn((500, 3)) + X0 = torch.randn((20, 3)) X0 = F.normalize(X0, p=2, dim=-1) X0.requires_grad_(True) - X1 = torch.randn((500, 3)) + X1 = torch.randn((20, 3)) X1 = F.normalize(X1, p=2, dim=-1) - sw = ot.sliced_wasserstein_sphere(X1, X0, n_projections=500, p=2) + sw = ot.sliced_wasserstein_sphere(X1, X0, n_projections=100, p=2) grad_x0 = torch.autograd.grad(sw, X0)[0] assert not torch.any(torch.isnan(grad_x0)) |