author     Gard Spreemann <gspr@nonempty.org>    2022-04-27 11:49:23 +0200
committer  Gard Spreemann <gspr@nonempty.org>    2022-04-27 11:49:23 +0200
commit     35bd2c98b642df78638d7d733bc1a89d873db1de (patch)
tree       6bc637624004713808d3097b95acdccbb9608e52 /test
parent     c4753bd3f74139af8380127b66b484bc09b50661 (diff)
parent     eccb1386eea52b94b82456d126bd20cbe3198e05 (diff)

Merge tag '0.8.2' into dfsg/latest
Diffstat (limited to 'test')
-rw-r--r--  test/test_1d_solver.py   |  28
-rw-r--r--  test/test_backend.py     |  66
-rw-r--r--  test/test_bregman.py     |  94
-rw-r--r--  test/test_da.py          | 307
-rw-r--r--  test/test_dr.py          |  22
-rw-r--r--  test/test_factored.py    |  56
-rw-r--r--  test/test_gpu.py         | 106
-rw-r--r--  test/test_gromov.py      | 726
-rw-r--r--  test/test_optim.py       |  17
-rw-r--r--  test/test_ot.py          |  42
-rw-r--r--  test/test_sliced.py      |  32
-rw-r--r--  test/test_stochastic.py  | 115
-rw-r--r--  test/test_unbalanced.py  | 207
-rw-r--r--  test/test_utils.py       |  22
-rw-r--r--  test/test_weak.py        |  52

15 files changed, 1310 insertions(+), 582 deletions(-)
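Most of the churn in the hunks below comes from one backend API change in 0.8.2: nx.from_numpy now converts several arrays in a single call and returns a matching tuple. A minimal sketch of the pattern, for orientation while reading the diff (illustrative only, not part of this commit; assumes POT >= 0.8.2 and its NumPy backend):

    import numpy as np
    from ot.backend import NumpyBackend

    nx = NumpyBackend()

    a = np.random.rand(5)     # histogram weights
    M = np.random.rand(5, 5)  # cost matrix

    # before 0.8.2: one conversion call per array
    ab = nx.from_numpy(a)
    Mb = nx.from_numpy(M)

    # since 0.8.2: one call converts them all and returns a tuple
    ab, Mb = nx.from_numpy(a, M)

The type_as keyword is unchanged and still controls dtype/device, as the *_type_devices tests below exercise.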
diff --git a/test/test_1d_solver.py b/test/test_1d_solver.py
index 6a42cfe..20f307a 100644
--- a/test/test_1d_solver.py
+++ b/test/test_1d_solver.py
@@ -66,9 +66,7 @@ def test_wasserstein_1d(nx):
rho_v = np.abs(rng.randn(n))
rho_v /= rho_v.sum()
- xb = nx.from_numpy(x)
- rho_ub = nx.from_numpy(rho_u)
- rho_vb = nx.from_numpy(rho_v)
+ xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v)
# test 1 : wasserstein_1d should be close to scipy W_1 implementation
np.testing.assert_almost_equal(wasserstein_1d(xb, xb, rho_ub, rho_vb, p=1),
@@ -98,9 +96,7 @@ def test_wasserstein_1d_type_devices(nx):
for tp in nx.__type_list__:
print(nx.dtype_device(tp))
- xb = nx.from_numpy(x, type_as=tp)
- rho_ub = nx.from_numpy(rho_u, type_as=tp)
- rho_vb = nx.from_numpy(rho_v, type_as=tp)
+ xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v, type_as=tp)
res = wasserstein_1d(xb, xb, rho_ub, rho_vb, p=1)
@@ -122,17 +118,13 @@ def test_wasserstein_1d_device_tf():
# Check that everything stays on the CPU
with tf.device("/CPU:0"):
- xb = nx.from_numpy(x)
- rho_ub = nx.from_numpy(rho_u)
- rho_vb = nx.from_numpy(rho_v)
+ xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v)
res = wasserstein_1d(xb, xb, rho_ub, rho_vb, p=1)
nx.assert_same_dtype_device(xb, res)
if len(tf.config.list_physical_devices('GPU')) > 0:
# Check that everything happens on the GPU
- xb = nx.from_numpy(x)
- rho_ub = nx.from_numpy(rho_u)
- rho_vb = nx.from_numpy(rho_v)
+ xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v)
res = wasserstein_1d(xb, xb, rho_ub, rho_vb, p=1)
nx.assert_same_dtype_device(xb, res)
assert nx.dtype_device(res)[1].startswith("GPU")
@@ -190,9 +182,7 @@ def test_emd1d_type_devices(nx):
for tp in nx.__type_list__:
print(nx.dtype_device(tp))
- xb = nx.from_numpy(x, type_as=tp)
- rho_ub = nx.from_numpy(rho_u, type_as=tp)
- rho_vb = nx.from_numpy(rho_v, type_as=tp)
+ xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v, type_as=tp)
emd = ot.emd_1d(xb, xb, rho_ub, rho_vb)
emd2 = ot.emd2_1d(xb, xb, rho_ub, rho_vb)
@@ -214,9 +204,7 @@ def test_emd1d_device_tf():
# Check that everything stays on the CPU
with tf.device("/CPU:0"):
- xb = nx.from_numpy(x)
- rho_ub = nx.from_numpy(rho_u)
- rho_vb = nx.from_numpy(rho_v)
+ xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v)
emd = ot.emd_1d(xb, xb, rho_ub, rho_vb)
emd2 = ot.emd2_1d(xb, xb, rho_ub, rho_vb)
nx.assert_same_dtype_device(xb, emd)
@@ -224,9 +212,7 @@ def test_emd1d_device_tf():
if len(tf.config.list_physical_devices('GPU')) > 0:
# Check that everything happens on the GPU
- xb = nx.from_numpy(x)
- rho_ub = nx.from_numpy(rho_u)
- rho_vb = nx.from_numpy(rho_v)
+ xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v)
emd = ot.emd_1d(xb, xb, rho_ub, rho_vb)
emd2 = ot.emd2_1d(xb, xb, rho_ub, rho_vb)
nx.assert_same_dtype_device(xb, emd)
diff --git a/test/test_backend.py b/test/test_backend.py
index 027c4cd..311c075 100644
--- a/test/test_backend.py
+++ b/test/test_backend.py
@@ -218,6 +218,8 @@ def test_empty_backend():
with pytest.raises(NotImplementedError):
nx.argmax(M)
with pytest.raises(NotImplementedError):
+ nx.argmin(M)
+ with pytest.raises(NotImplementedError):
nx.mean(M)
with pytest.raises(NotImplementedError):
nx.std(M)
@@ -264,12 +266,27 @@ def test_empty_backend():
nx.device_type(M)
with pytest.raises(NotImplementedError):
nx._bench(lambda x: x, M, n_runs=1)
+ with pytest.raises(NotImplementedError):
+ nx.solve(M, v)
+ with pytest.raises(NotImplementedError):
+ nx.trace(M)
+ with pytest.raises(NotImplementedError):
+ nx.inv(M)
+ with pytest.raises(NotImplementedError):
+ nx.sqrtm(M)
+ with pytest.raises(NotImplementedError):
+ nx.isfinite(M)
+ with pytest.raises(NotImplementedError):
+ nx.array_equal(M, M)
+ with pytest.raises(NotImplementedError):
+ nx.is_floating_point(M)
def test_func_backends(nx):
rnd = np.random.RandomState(0)
M = rnd.randn(10, 3)
+ SquareM = rnd.randn(10, 10)
v = rnd.randn(3)
val = np.array([1.0])
@@ -288,6 +305,7 @@ def test_func_backends(nx):
lst_name = []
Mb = nx.from_numpy(M)
+ SquareMb = nx.from_numpy(SquareM)
vb = nx.from_numpy(v)
val = nx.from_numpy(val)
@@ -467,6 +485,10 @@ def test_func_backends(nx):
lst_b.append(nx.to_numpy(A))
lst_name.append('argmax')
+ A = nx.argmin(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('argmin')
+
A = nx.mean(Mb)
lst_b.append(nx.to_numpy(A))
lst_name.append('mean')
@@ -529,7 +551,11 @@ def test_func_backends(nx):
A = nx.where(Mb >= nx.stack([nx.linspace(0, 1, 10)] * 3, axis=1), Mb, 0.0)
lst_b.append(nx.to_numpy(A))
- lst_name.append('where')
+ lst_name.append('where (cond, x, y)')
+
+ A = nx.where(nx.from_numpy(np.array([True, False])))
+ lst_b.append(nx.to_numpy(nx.stack(A)))
+ lst_name.append('where (cond)')
A = nx.copy(Mb)
lst_b.append(nx.to_numpy(A))
@@ -550,15 +576,47 @@ def test_func_backends(nx):
nx._bench(lambda x: x, M, n_runs=1)
+ A = nx.solve(SquareMb, Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('solve')
+
+ A = nx.trace(SquareMb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('trace')
+
+ A = nx.inv(SquareMb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('matrix inverse')
+
+ A = nx.sqrtm(SquareMb.T @ SquareMb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append("matrix square root")
+
+ A = nx.concatenate([vb, nx.from_numpy(np.array([np.inf, np.nan]))], axis=0)
+ A = nx.isfinite(A)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append("isfinite")
+
+ assert not nx.array_equal(Mb, vb), "array_equal (shape)"
+ assert nx.array_equal(Mb, Mb), "array_equal (elements) - expected true"
+ assert not nx.array_equal(
+ Mb, Mb + nx.eye(*list(Mb.shape))
+ ), "array_equal (elements) - expected false"
+
+ assert nx.is_floating_point(Mb), "is_floating_point - expected true"
+ assert not nx.is_floating_point(
+ nx.from_numpy(np.array([0, 1, 2], dtype=int))
+ ), "is_floating_point - expected false"
+
lst_tot.append(lst_b)
lst_np = lst_tot[0]
lst_b = lst_tot[1]
for a1, a2, name in zip(lst_np, lst_b, lst_name):
- if not np.allclose(a1, a2):
- print('Assert fail on: ', name)
- assert np.allclose(a1, a2, atol=1e-7)
+ np.testing.assert_allclose(
+ a2, a1, atol=1e-7, err_msg=f'ASSERT FAILED ON: {name}'
+ )
def test_random_backends(nx):
diff --git a/test/test_bregman.py b/test/test_bregman.py
index 6e90aa4..6c37984 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -60,7 +60,7 @@ def test_convergence_warning(method):
ot.sinkhorn2(a1, a2, M, 1, method=method, stopThr=0, numItermax=1)
-def test_not_impemented_method():
+def test_not_implemented_method():
# test sinkhorn
w = 10
n = w ** 2
@@ -155,8 +155,7 @@ def test_sinkhorn_backends(nx):
G = ot.sinkhorn(a, a, M, 1)
- ab = nx.from_numpy(a)
- M_nx = nx.from_numpy(M)
+ ab, M_nx = nx.from_numpy(a, M)
Gb = ot.sinkhorn(ab, ab, M_nx, 1)
@@ -176,8 +175,7 @@ def test_sinkhorn2_backends(nx):
G = ot.sinkhorn(a, a, M, 1)
- ab = nx.from_numpy(a)
- M_nx = nx.from_numpy(M)
+ ab, M_nx = nx.from_numpy(a, M)
Gb = ot.sinkhorn2(ab, ab, M_nx, 1)
@@ -260,8 +258,7 @@ def test_sinkhorn_variants(nx):
M = ot.dist(x, x)
- ub = nx.from_numpy(u)
- M_nx = nx.from_numpy(M)
+ ub, M_nx = nx.from_numpy(u, M)
G = ot.sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10)
Gl = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method='sinkhorn_log', stopThr=1e-10))
@@ -298,8 +295,7 @@ def test_sinkhorn_variants_dtype_device(nx, method):
for tp in nx.__type_list__:
print(nx.dtype_device(tp))
- ub = nx.from_numpy(u, type_as=tp)
- Mb = nx.from_numpy(M, type_as=tp)
+ ub, Mb = nx.from_numpy(u, M, type_as=tp)
Gb = ot.sinkhorn(ub, ub, Mb, 1, method=method, stopThr=1e-10)
@@ -318,8 +314,7 @@ def test_sinkhorn2_variants_dtype_device(nx, method):
for tp in nx.__type_list__:
print(nx.dtype_device(tp))
- ub = nx.from_numpy(u, type_as=tp)
- Mb = nx.from_numpy(M, type_as=tp)
+ ub, Mb = nx.from_numpy(u, M, type_as=tp)
lossb = ot.sinkhorn2(ub, ub, Mb, 1, method=method, stopThr=1e-10)
@@ -337,8 +332,7 @@ def test_sinkhorn2_variants_device_tf(method):
# Check that everything stays on the CPU
with tf.device("/CPU:0"):
- ub = nx.from_numpy(u)
- Mb = nx.from_numpy(M)
+ ub, Mb = nx.from_numpy(u, M)
Gb = ot.sinkhorn(ub, ub, Mb, 1, method=method, stopThr=1e-10)
lossb = ot.sinkhorn2(ub, ub, Mb, 1, method=method, stopThr=1e-10)
nx.assert_same_dtype_device(Mb, Gb)
@@ -346,8 +340,7 @@ def test_sinkhorn2_variants_device_tf(method):
if len(tf.config.list_physical_devices('GPU')) > 0:
# Check that everything happens on the GPU
- ub = nx.from_numpy(u)
- Mb = nx.from_numpy(M)
+ ub, Mb = nx.from_numpy(u, M)
Gb = ot.sinkhorn(ub, ub, Mb, 1, method=method, stopThr=1e-10)
lossb = ot.sinkhorn2(ub, ub, Mb, 1, method=method, stopThr=1e-10)
nx.assert_same_dtype_device(Mb, Gb)
@@ -370,9 +363,7 @@ def test_sinkhorn_variants_multi_b(nx):
M = ot.dist(x, x)
- ub = nx.from_numpy(u)
- bb = nx.from_numpy(b)
- M_nx = nx.from_numpy(M)
+ ub, bb, M_nx = nx.from_numpy(u, b, M)
G = ot.sinkhorn(u, b, M, 1, method='sinkhorn', stopThr=1e-10)
Gl = nx.to_numpy(ot.sinkhorn(ub, bb, M_nx, 1, method='sinkhorn_log', stopThr=1e-10))
@@ -400,9 +391,7 @@ def test_sinkhorn2_variants_multi_b(nx):
M = ot.dist(x, x)
- ub = nx.from_numpy(u)
- bb = nx.from_numpy(b)
- M_nx = nx.from_numpy(M)
+ ub, bb, M_nx = nx.from_numpy(u, b, M)
G = ot.sinkhorn2(u, b, M, 1, method='sinkhorn', stopThr=1e-10)
Gl = nx.to_numpy(ot.sinkhorn2(ub, bb, M_nx, 1, method='sinkhorn_log', stopThr=1e-10))
@@ -483,9 +472,7 @@ def test_barycenter(nx, method, verbose, warn):
alpha = 0.5 # 0<=alpha<=1
weights = np.array([1 - alpha, alpha])
- A_nx = nx.from_numpy(A)
- M_nx = nx.from_numpy(M)
- weights_nx = nx.from_numpy(weights)
+ A_nx, M_nx, weights_nx = nx.from_numpy(A, M, weights)
reg = 1e-2
if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log":
@@ -523,9 +510,7 @@ def test_barycenter_debiased(nx, method, verbose, warn):
alpha = 0.5 # 0<=alpha<=1
weights = np.array([1 - alpha, alpha])
- A_nx = nx.from_numpy(A)
- M_nx = nx.from_numpy(M)
- weights_nx = nx.from_numpy(weights)
+ A_nx, M_nx, weights_nx = nx.from_numpy(A, M, weights)
# wasserstein
reg = 1e-2
@@ -594,9 +579,7 @@ def test_barycenter_stabilization(nx):
alpha = 0.5 # 0<=alpha<=1
weights = np.array([1 - alpha, alpha])
- A_nx = nx.from_numpy(A)
- M_nx = nx.from_numpy(M)
- weights_b = nx.from_numpy(weights)
+ A_nx, M_nx, weights_b = nx.from_numpy(A, M, weights)
# wasserstein
reg = 1e-2
@@ -635,7 +618,7 @@ def test_wasserstein_bary_2d(nx, method):
with pytest.raises(NotImplementedError):
ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method)
else:
- bary_wass_np = ot.bregman.convolutional_barycenter2d(A, reg, method=method)
+ bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d(A, reg, method=method, verbose=True, log=True)
bary_wass = nx.to_numpy(ot.bregman.convolutional_barycenter2d(A_nx, reg,
method=method))
np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
@@ -667,7 +650,7 @@ def test_wasserstein_bary_2d_debiased(nx, method):
with pytest.raises(NotImplementedError):
ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method)
else:
- bary_wass_np = ot.bregman.convolutional_barycenter2d_debiased(A, reg, method=method)
+ bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d_debiased(A, reg, method=method, verbose=True, log=True)
bary_wass = nx.to_numpy(ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method))
np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
@@ -697,11 +680,7 @@ def test_unmix(nx):
M0 /= M0.max()
h0 = ot.unif(2)
- ab = nx.from_numpy(a)
- Db = nx.from_numpy(D)
- M_nx = nx.from_numpy(M)
- M0b = nx.from_numpy(M0)
- h0b = nx.from_numpy(h0)
+ ab, Db, M_nx, M0b, h0b = nx.from_numpy(a, D, M, M0, h0)
# wasserstein
reg = 1e-3
@@ -727,12 +706,7 @@ def test_empirical_sinkhorn(nx):
M = ot.dist(X_s, X_t)
M_m = ot.dist(X_s, X_t, metric='euclidean')
- ab = nx.from_numpy(a)
- bb = nx.from_numpy(b)
- X_sb = nx.from_numpy(X_s)
- X_tb = nx.from_numpy(X_t)
- M_nx = nx.from_numpy(M, type_as=ab)
- M_mb = nx.from_numpy(M_m, type_as=ab)
+ ab, bb, X_sb, X_tb, M_nx, M_mb = nx.from_numpy(a, b, X_s, X_t, M, M_m)
G_sqe = nx.to_numpy(ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1))
sinkhorn_sqe = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1))
@@ -776,12 +750,7 @@ def test_lazy_empirical_sinkhorn(nx):
M = ot.dist(X_s, X_t)
M_m = ot.dist(X_s, X_t, metric='euclidean')
- ab = nx.from_numpy(a)
- bb = nx.from_numpy(b)
- X_sb = nx.from_numpy(X_s)
- X_tb = nx.from_numpy(X_t)
- M_nx = nx.from_numpy(M, type_as=ab)
- M_mb = nx.from_numpy(M_m, type_as=ab)
+ ab, bb, X_sb, X_tb, M_nx, M_mb = nx.from_numpy(a, b, X_s, X_t, M, M_m)
f, g = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=(1, 3), verbose=True)
f, g = nx.to_numpy(f), nx.to_numpy(g)
@@ -825,19 +794,13 @@ def test_empirical_sinkhorn_divergence(nx):
a = np.linspace(1, n, n)
a /= a.sum()
b = ot.unif(n)
- X_s = np.reshape(np.arange(n), (n, 1))
- X_t = np.reshape(np.arange(0, n * 2, 2), (n, 1))
+ X_s = np.reshape(np.arange(n, dtype=np.float64), (n, 1))
+ X_t = np.reshape(np.arange(0, n * 2, 2, dtype=np.float64), (n, 1))
M = ot.dist(X_s, X_t)
M_s = ot.dist(X_s, X_s)
M_t = ot.dist(X_t, X_t)
- ab = nx.from_numpy(a)
- bb = nx.from_numpy(b)
- X_sb = nx.from_numpy(X_s)
- X_tb = nx.from_numpy(X_t)
- M_nx = nx.from_numpy(M, type_as=ab)
- M_sb = nx.from_numpy(M_s, type_as=ab)
- M_tb = nx.from_numpy(M_t, type_as=ab)
+ ab, bb, X_sb, X_tb, M_nx, M_sb, M_tb = nx.from_numpy(a, b, X_s, X_t, M, M_s, M_t)
emp_sinkhorn_div = nx.to_numpy(ot.bregman.empirical_sinkhorn_divergence(X_sb, X_tb, 1, a=ab, b=bb))
sinkhorn_div = nx.to_numpy(
@@ -872,9 +835,7 @@ def test_stabilized_vs_sinkhorn_multidim(nx):
M /= np.median(M)
epsilon = 0.1
- ab = nx.from_numpy(a)
- bb = nx.from_numpy(b)
- M_nx = nx.from_numpy(M, type_as=ab)
+ ab, bb, M_nx = nx.from_numpy(a, b, M)
G_np, _ = ot.bregman.sinkhorn(a, b, M, reg=epsilon, method="sinkhorn", log=True)
G, log = ot.bregman.sinkhorn(ab, bb, M_nx, reg=epsilon,
@@ -936,18 +897,13 @@ def test_screenkhorn(nx):
x = rng.randn(n, 2)
M = ot.dist(x, x)
- ab = nx.from_numpy(a)
- bb = nx.from_numpy(b)
- M_nx = nx.from_numpy(M, type_as=ab)
+ ab, bb, M_nx = nx.from_numpy(a, b, M)
- # np sinkhorn
- G_sink_np = ot.sinkhorn(a, b, M, 1e-03)
# sinkhorn
- G_sink = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1e-03))
+ G_sink = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1e-1))
# screenkhorn
- G_screen = nx.to_numpy(ot.bregman.screenkhorn(ab, bb, M_nx, 1e-03, uniform=True, verbose=True))
+ G_screen = nx.to_numpy(ot.bregman.screenkhorn(ab, bb, M_nx, 1e-1, uniform=True, verbose=True))
# check marginals
- np.testing.assert_allclose(G_sink_np, G_sink)
np.testing.assert_allclose(G_sink.sum(0), G_screen.sum(0), atol=1e-02)
np.testing.assert_allclose(G_sink.sum(1), G_screen.sum(1), atol=1e-02)
diff --git a/test/test_da.py b/test/test_da.py
index 9f2bb50..4bf0ab1 100644
--- a/test/test_da.py
+++ b/test/test_da.py
@@ -19,7 +19,32 @@ except ImportError:
nosklearn = True
-def test_sinkhorn_lpl1_transport_class():
+def test_class_jax_tf():
+ backends = []
+ from ot.backend import jax, tf
+ if jax:
+ backends.append(ot.backend.JaxBackend())
+ if tf:
+ backends.append(ot.backend.TensorflowBackend())
+
+ for nx in backends:
+ ns = 150
+ nt = 200
+
+ Xs, ys = make_data_classif('3gauss', ns)
+ Xt, yt = make_data_classif('3gauss2', nt)
+
+ Xs, ys, Xt, yt = nx.from_numpy(Xs, ys, Xt, yt)
+
+ otda = ot.da.SinkhornLpl1Transport()
+
+ with pytest.raises(TypeError):
+ otda.fit(Xs=Xs, ys=ys, Xt=Xt)
+
+
+@pytest.skip_backend("jax")
+@pytest.skip_backend("tf")
+def test_sinkhorn_lpl1_transport_class(nx):
"""test_sinkhorn_transport
"""
@@ -29,6 +54,8 @@
Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
+ Xs, ys, Xt, yt = nx.from_numpy(Xs, ys, Xt, yt)
+
otda = ot.da.SinkhornLpl1Transport()
# test its computed
@@ -44,15 +71,15 @@
mu_s = unif(ns)
mu_t = unif(nt)
assert_allclose(
- np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
+ nx.to_numpy(nx.sum(otda.coupling_, axis=0)), mu_t, rtol=1e-3, atol=1e-3)
assert_allclose(
- np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
+ nx.to_numpy(nx.sum(otda.coupling_, axis=1)), mu_s, rtol=1e-3, atol=1e-3)
# test transform
transp_Xs = otda.transform(Xs=Xs)
assert_equal(transp_Xs.shape, Xs.shape)
- Xs_new, _ = make_data_classif('3gauss', ns + 1)
+ Xs_new = nx.from_numpy(make_data_classif('3gauss', ns + 1)[0])
transp_Xs_new = otda.transform(Xs_new)
# check that the oos method is working
@@ -62,7 +89,7 @@
transp_Xt = otda.inverse_transform(Xt=Xt)
assert_equal(transp_Xt.shape, Xt.shape)
- Xt_new, _ = make_data_classif('3gauss2', nt + 1)
+ Xt_new = nx.from_numpy(make_data_classif('3gauss2', nt + 1)[0])
transp_Xt_new = otda.inverse_transform(Xt=Xt_new)
# check that the oos method is working
@@ -85,24 +112,26 @@
# test unsupervised vs semi-supervised mode
otda_unsup = ot.da.SinkhornLpl1Transport()
otda_unsup.fit(Xs=Xs, ys=ys, Xt=Xt)
- n_unsup = np.sum(otda_unsup.cost_)
+ n_unsup = nx.sum(otda_unsup.cost_)
otda_semi = ot.da.SinkhornLpl1Transport()
otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
assert_equal(otda_semi.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
- n_semisup = np.sum(otda_semi.cost_)
+ n_semisup = nx.sum(otda_semi.cost_)
# check that the cost matrix norms are indeed different
assert n_unsup != n_semisup, "semisupervised mode not working"
# check that the coupling forbids mass transport between labeled source
# and labeled target samples
- mass_semi = np.sum(
+ mass_semi = nx.sum(
otda_semi.coupling_[otda_semi.cost_ == otda_semi.limit_max])
assert mass_semi == 0, "semisupervised mode not working"
-def test_sinkhorn_l1l2_transport_class():
+@pytest.skip_backend("jax")
+@pytest.skip_backend("tf")
+def test_sinkhorn_l1l2_transport_class(nx):
"""test_sinkhorn_transport
"""
@@ -112,6 +141,8 @@
Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
+ Xs, ys, Xt, yt = nx.from_numpy(Xs, ys, Xt, yt)
+
otda = ot.da.SinkhornL1l2Transport()
# test its computed
@@ -128,15 +159,15 @@
mu_s = unif(ns)
mu_t = unif(nt)
assert_allclose(
- np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
+ nx.to_numpy(nx.sum(otda.coupling_, axis=0)), mu_t, rtol=1e-3, atol=1e-3)
assert_allclose(
- np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
+ nx.to_numpy(nx.sum(otda.coupling_, axis=1)), mu_s, rtol=1e-3, atol=1e-3)
# test transform
transp_Xs = otda.transform(Xs=Xs)
assert_equal(transp_Xs.shape, Xs.shape)
- Xs_new, _ = make_data_classif('3gauss', ns + 1)
+ Xs_new = nx.from_numpy(make_data_classif('3gauss', ns + 1)[0])
transp_Xs_new = otda.transform(Xs_new)
# check that the oos method is working
@@ -156,7 +187,7 @@
assert_equal(transp_ys.shape[0], ys.shape[0])
assert_equal(transp_ys.shape[1], len(np.unique(yt)))
- Xt_new, _ = make_data_classif('3gauss2', nt + 1)
+ Xt_new = nx.from_numpy(make_data_classif('3gauss2', nt + 1)[0])
transp_Xt_new = otda.inverse_transform(Xt=Xt_new)
# check that the oos method is working
@@ -169,22 +200,22 @@
# test unsupervised vs semi-supervised mode
otda_unsup = ot.da.SinkhornL1l2Transport()
otda_unsup.fit(Xs=Xs, ys=ys, Xt=Xt)
- n_unsup = np.sum(otda_unsup.cost_)
+ n_unsup = nx.sum(otda_unsup.cost_)
otda_semi = ot.da.SinkhornL1l2Transport()
otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
assert_equal(otda_semi.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
- n_semisup = np.sum(otda_semi.cost_)
+ n_semisup = nx.sum(otda_semi.cost_)
# check that the cost matrix norms are indeed different
assert n_unsup != n_semisup, "semisupervised mode not working"
# check that the coupling forbids mass transport between labeled source
# and labeled target samples
- mass_semi = np.sum(
+ mass_semi = nx.sum(
otda_semi.coupling_[otda_semi.cost_ == otda_semi.limit_max])
mass_semi = otda_semi.coupling_[otda_semi.cost_ == otda_semi.limit_max]
- assert_allclose(mass_semi, np.zeros_like(mass_semi),
+ assert_allclose(nx.to_numpy(mass_semi), np.zeros(list(mass_semi.shape)),
rtol=1e-9, atol=1e-9)
# check everything runs well with log=True
@@ -193,7 +224,9 @@
assert len(otda.log_.keys()) != 0
-def test_sinkhorn_transport_class():
+@pytest.skip_backend("jax")
+@pytest.skip_backend("tf")
+def test_sinkhorn_transport_class(nx):
"""test_sinkhorn_transport
"""
@@ -203,6 +236,8 @@
Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
+ Xs, ys, Xt, yt = nx.from_numpy(Xs, ys, Xt, yt)
+
otda = ot.da.SinkhornTransport()
# test its computed
@@ -219,15 +254,15 @@
mu_s = unif(ns)
mu_t = unif(nt)
assert_allclose(
- np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
+ nx.to_numpy(nx.sum(otda.coupling_, axis=0)), mu_t, rtol=1e-3, atol=1e-3)
assert_allclose(
- np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
+ nx.to_numpy(nx.sum(otda.coupling_, axis=1)), mu_s, rtol=1e-3, atol=1e-3)
# test transform
transp_Xs = otda.transform(Xs=Xs)
assert_equal(transp_Xs.shape, Xs.shape)
- Xs_new, _ = make_data_classif('3gauss', ns + 1)
+ Xs_new = nx.from_numpy(make_data_classif('3gauss', ns + 1)[0])
transp_Xs_new = otda.transform(Xs_new)
# check that the oos method is working
@@ -247,7 +282,7 @@
assert_equal(transp_ys.shape[0], ys.shape[0])
assert_equal(transp_ys.shape[1], len(np.unique(yt)))
- Xt_new, _ = make_data_classif('3gauss2', nt + 1)
+ Xt_new = nx.from_numpy(make_data_classif('3gauss2', nt + 1)[0])
transp_Xt_new = otda.inverse_transform(Xt=Xt_new)
# check that the oos method is working
@@ -260,19 +295,19 @@
# test unsupervised vs semi-supervised mode
otda_unsup = ot.da.SinkhornTransport()
otda_unsup.fit(Xs=Xs, Xt=Xt)
- n_unsup = np.sum(otda_unsup.cost_)
+ n_unsup = nx.sum(otda_unsup.cost_)
otda_semi = ot.da.SinkhornTransport()
otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
assert_equal(otda_semi.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
- n_semisup = np.sum(otda_semi.cost_)
+ n_semisup = nx.sum(otda_semi.cost_)
# check that the cost matrix norms are indeed different
assert n_unsup != n_semisup, "semisupervised mode not working"
# check that the coupling forbids mass transport between labeled source
# and labeled target samples
- mass_semi = np.sum(
+ mass_semi = nx.sum(
otda_semi.coupling_[otda_semi.cost_ == otda_semi.limit_max])
assert mass_semi == 0, "semisupervised mode not working"
@@ -282,7 +317,9 @@
assert len(otda.log_.keys()) != 0
-def test_unbalanced_sinkhorn_transport_class():
+@pytest.skip_backend("jax")
+@pytest.skip_backend("tf")
+def test_unbalanced_sinkhorn_transport_class(nx):
"""test_sinkhorn_transport
"""
@@ -292,6 +329,8 @@
Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
+ Xs, ys, Xt, yt = nx.from_numpy(Xs, ys, Xt, yt)
+
otda = ot.da.UnbalancedSinkhornTransport()
# test its computed
@@ -318,7 +357,7 @@
assert_equal(transp_ys.shape[0], ys.shape[0])
assert_equal(transp_ys.shape[1], len(np.unique(yt)))
- Xs_new, _ = make_data_classif('3gauss', ns + 1)
+ Xs_new = nx.from_numpy(make_data_classif('3gauss', ns + 1)[0])
transp_Xs_new = otda.transform(Xs_new)
# check that the oos method is working
@@ -328,7 +367,7 @@
transp_Xt = otda.inverse_transform(Xt=Xt)
assert_equal(transp_Xt.shape, Xt.shape)
- Xt_new, _ = make_data_classif('3gauss2', nt + 1)
+ Xt_new = nx.from_numpy(make_data_classif('3gauss2', nt + 1)[0])
transp_Xt_new = otda.inverse_transform(Xt=Xt_new)
# check that the oos method is working
@@ -341,12 +380,12 @@
# test unsupervised vs semi-supervised mode
otda_unsup = ot.da.SinkhornTransport()
otda_unsup.fit(Xs=Xs, Xt=Xt)
- n_unsup = np.sum(otda_unsup.cost_)
+ n_unsup = nx.sum(otda_unsup.cost_)
otda_semi = ot.da.SinkhornTransport()
otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
assert_equal(otda_semi.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
- n_semisup = np.sum(otda_semi.cost_)
+ n_semisup = nx.sum(otda_semi.cost_)
# check that the cost matrix norms are indeed different
assert n_unsup != n_semisup, "semisupervised mode not working"
@@ -357,7 +396,9 @@
assert len(otda.log_.keys()) != 0
-def test_emd_transport_class():
+@pytest.skip_backend("jax")
+@pytest.skip_backend("tf")
+def test_emd_transport_class(nx):
"""test_sinkhorn_transport
"""
@@ -367,6 +408,8 @@
Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
+ Xs, ys, Xt, yt = nx.from_numpy(Xs, ys, Xt, yt)
+
otda = ot.da.EMDTransport()
# test its computed
@@ -382,15 +425,15 @@
mu_s = unif(ns)
mu_t = unif(nt)
assert_allclose(
- np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
+ nx.to_numpy(nx.sum(otda.coupling_, axis=0)), mu_t, rtol=1e-3, atol=1e-3)
assert_allclose(
- np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
+ nx.to_numpy(nx.sum(otda.coupling_, axis=1)), mu_s, rtol=1e-3, atol=1e-3)
# test transform
transp_Xs = otda.transform(Xs=Xs)
assert_equal(transp_Xs.shape, Xs.shape)
- Xs_new, _ = make_data_classif('3gauss', ns + 1)
+ Xs_new = nx.from_numpy(make_data_classif('3gauss', ns + 1)[0])
transp_Xs_new = otda.transform(Xs_new)
# check that the oos method is working
@@ -410,7 +453,7 @@
assert_equal(transp_ys.shape[0], ys.shape[0])
assert_equal(transp_ys.shape[1], len(np.unique(yt)))
- Xt_new, _ = make_data_classif('3gauss2', nt + 1)
+ Xt_new = nx.from_numpy(make_data_classif('3gauss2', nt + 1)[0])
transp_Xt_new = otda.inverse_transform(Xt=Xt_new)
# check that the oos method is working
@@ -423,28 +466,32 @@
# test unsupervised vs semi-supervised mode
otda_unsup = ot.da.EMDTransport()
otda_unsup.fit(Xs=Xs, ys=ys, Xt=Xt)
- n_unsup = np.sum(otda_unsup.cost_)
+ n_unsup = nx.sum(otda_unsup.cost_)
otda_semi = ot.da.EMDTransport()
otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
assert_equal(otda_semi.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
- n_semisup = np.sum(otda_semi.cost_)
+ n_semisup = nx.sum(otda_semi.cost_)
# check that the cost matrix norms are indeed different
assert n_unsup != n_semisup, "semisupervised mode not working"
# check that the coupling forbids mass transport between labeled source
# and labeled target samples
- mass_semi = np.sum(
+ mass_semi = nx.sum(
otda_semi.coupling_[otda_semi.cost_ == otda_semi.limit_max])
mass_semi = otda_semi.coupling_[otda_semi.cost_ == otda_semi.limit_max]
# we need to use a small tolerance here, otherwise the test breaks
- assert_allclose(mass_semi, np.zeros_like(mass_semi),
+ assert_allclose(nx.to_numpy(mass_semi), np.zeros(list(mass_semi.shape)),
rtol=1e-2, atol=1e-2)
-def test_mapping_transport_class():
+@pytest.skip_backend("jax")
+@pytest.skip_backend("tf")
+@pytest.mark.parametrize("kernel", ["linear", "gaussian"])
+@pytest.mark.parametrize("bias", ["unbiased", "biased"])
+def test_mapping_transport_class(nx, kernel, bias):
"""test_mapping_transport
"""
@@ -455,101 +502,29 @@
Xt, yt = make_data_classif('3gauss2', nt)
Xs_new, _ = make_data_classif('3gauss', ns + 1)
- ##########################################################################
- # kernel == linear mapping tests
- ##########################################################################
+ Xs, Xt, Xs_new = nx.from_numpy(Xs, Xt, Xs_new)
- # check computation and dimensions if bias == False
- otda = ot.da.MappingTransport(kernel="linear", bias=False)
+ # Mapping tests
+ bias = bias == "biased"
+ otda = ot.da.MappingTransport(kernel=kernel, bias=bias)
otda.fit(Xs=Xs, Xt=Xt)
assert hasattr(otda, "coupling_")
assert hasattr(otda, "mapping_")
assert hasattr(otda, "log_")
assert_equal(otda.coupling_.shape, ((Xs.shape[0], Xt.shape[0])))
- assert_equal(otda.mapping_.shape, ((Xs.shape[1], Xt.shape[1])))
+ S = Xs.shape[0] if kernel == "gaussian" else Xs.shape[1] # if linear
+ if bias:
+ S += 1
+ assert_equal(otda.mapping_.shape, ((S, Xt.shape[1])))
# test margin constraints
mu_s = unif(ns)
mu_t = unif(nt)
assert_allclose(
- np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
+ nx.to_numpy(nx.sum(otda.coupling_, axis=0)), mu_t, rtol=1e-3, atol=1e-3)
assert_allclose(
- np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
-
- # test transform
- transp_Xs = otda.transform(Xs=Xs)
- assert_equal(transp_Xs.shape, Xs.shape)
-
- transp_Xs_new = otda.transform(Xs_new)
-
- # check that the oos method is working
- assert_equal(transp_Xs_new.shape, Xs_new.shape)
-
- # check computation and dimensions if bias == True
- otda = ot.da.MappingTransport(kernel="linear", bias=True)
- otda.fit(Xs=Xs, Xt=Xt)
- assert_equal(otda.coupling_.shape, ((Xs.shape[0], Xt.shape[0])))
- assert_equal(otda.mapping_.shape, ((Xs.shape[1] + 1, Xt.shape[1])))
-
- # test margin constraints
- mu_s = unif(ns)
- mu_t = unif(nt)
- assert_allclose(
- np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
- assert_allclose(
- np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
-
- # test transform
- transp_Xs = otda.transform(Xs=Xs)
- assert_equal(transp_Xs.shape, Xs.shape)
-
- transp_Xs_new = otda.transform(Xs_new)
-
- # check that the oos method is working
- assert_equal(transp_Xs_new.shape, Xs_new.shape)
-
- ##########################################################################
- # kernel == gaussian mapping tests
- ##########################################################################
-
- # check computation and dimensions if bias == False
- otda = ot.da.MappingTransport(kernel="gaussian", bias=False)
- otda.fit(Xs=Xs, Xt=Xt)
-
- assert_equal(otda.coupling_.shape, ((Xs.shape[0], Xt.shape[0])))
- assert_equal(otda.mapping_.shape, ((Xs.shape[0], Xt.shape[1])))
-
- # test margin constraints
- mu_s = unif(ns)
- mu_t = unif(nt)
- assert_allclose(
- np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
- assert_allclose(
- np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
-
- # test transform
- transp_Xs = otda.transform(Xs=Xs)
- assert_equal(transp_Xs.shape, Xs.shape)
-
- transp_Xs_new = otda.transform(Xs_new)
-
- # check that the oos method is working
- assert_equal(transp_Xs_new.shape, Xs_new.shape)
-
- # check computation and dimensions if bias == True
- otda = ot.da.MappingTransport(kernel="gaussian", bias=True)
- otda.fit(Xs=Xs, Xt=Xt)
- assert_equal(otda.coupling_.shape, ((Xs.shape[0], Xt.shape[0])))
- assert_equal(otda.mapping_.shape, ((Xs.shape[0] + 1, Xt.shape[1])))
-
- # test margin constraints
- mu_s = unif(ns)
- mu_t = unif(nt)
- assert_allclose(
- np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
- assert_allclose(
- np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
+ nx.to_numpy(nx.sum(otda.coupling_, axis=1)), mu_s, rtol=1e-3, atol=1e-3)
# test transform
transp_Xs = otda.transform(Xs=Xs)
@@ -561,29 +536,39 @@
assert_equal(transp_Xs_new.shape, Xs_new.shape)
# check everything runs well with log=True
- otda = ot.da.MappingTransport(kernel="gaussian", log=True)
+ otda = ot.da.MappingTransport(kernel=kernel, bias=bias, log=True)
otda.fit(Xs=Xs, Xt=Xt)
assert len(otda.log_.keys()) != 0
+
+@pytest.skip_backend("jax")
+@pytest.skip_backend("tf")
+def test_mapping_transport_class_specific_seed(nx):
# check that it does not crash when derphi is very close to 0
+ ns = 20
+ nt = 30
np.random.seed(39)
Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
otda = ot.da.MappingTransport(kernel="gaussian", bias=False)
- otda.fit(Xs=Xs, Xt=Xt)
+ otda.fit(Xs=nx.from_numpy(Xs), Xt=nx.from_numpy(Xt))
np.random.seed(None)
-def test_linear_mapping():
+@pytest.skip_backend("jax")
+@pytest.skip_backend("tf")
+def test_linear_mapping(nx):
ns = 150
nt = 200
Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
- A, b = ot.da.OT_mapping_linear(Xs, Xt)
+ Xsb, Xtb = nx.from_numpy(Xs, Xt)
- Xst = Xs.dot(A) + b
+ A, b = ot.da.OT_mapping_linear(Xsb, Xtb)
+
+ Xst = nx.to_numpy(nx.dot(Xsb, A) + b)
Ct = np.cov(Xt.T)
Cst = np.cov(Xst.T)
@@ -591,22 +576,26 @@
np.testing.assert_allclose(Ct, Cst, rtol=1e-2, atol=1e-2)
-def test_linear_mapping_class():
+@pytest.skip_backend("jax")
+@pytest.skip_backend("tf")
+def test_linear_mapping_class(nx):
ns = 150
nt = 200
Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
+ Xsb, Xtb = nx.from_numpy(Xs, Xt)
+
otmap = ot.da.LinearTransport()
- otmap.fit(Xs=Xs, Xt=Xt)
+ otmap.fit(Xs=Xsb, Xt=Xtb)
assert hasattr(otmap, "A_")
assert hasattr(otmap, "B_")
assert hasattr(otmap, "A1_")
assert hasattr(otmap, "B1_")
- Xst = otmap.transform(Xs=Xs)
+ Xst = nx.to_numpy(otmap.transform(Xs=Xsb))
Ct = np.cov(Xt.T)
Cst = np.cov(Xst.T)
@@ -614,7 +603,9 @@
np.testing.assert_allclose(Ct, Cst, rtol=1e-2, atol=1e-2)
-def test_jcpot_transport_class():
+@pytest.skip_backend("jax")
+@pytest.skip_backend("tf")
+def test_jcpot_transport_class(nx):
"""test_jcpot_transport
"""
@@ -627,6 +618,8 @@
Xt, yt = make_data_classif('3gauss2', nt)
+ Xs1, ys1, Xs2, ys2, Xt, yt = nx.from_numpy(Xs1, ys1, Xs2, ys2, Xt, yt)
+
Xs = [Xs1, Xs2]
ys = [ys1, ys2]
@@ -649,19 +642,24 @@
for i in range(len(Xs)):
# test margin constraints w.r.t. uniform target weights for each coupling matrix
assert_allclose(
- np.sum(otda.coupling_[i], axis=0), mu_t, rtol=1e-3, atol=1e-3)
+ nx.to_numpy(nx.sum(otda.coupling_[i], axis=0)), mu_t, rtol=1e-3, atol=1e-3)
# test margin constraints w.r.t. modified source weights for each source domain
assert_allclose(
- np.dot(otda.log_['D1'][i], np.sum(otda.coupling_[i], axis=1)), otda.proportions_, rtol=1e-3,
- atol=1e-3)
+ nx.to_numpy(
+ nx.dot(otda.log_['D1'][i], nx.sum(otda.coupling_[i], axis=1))
+ ),
+ nx.to_numpy(otda.proportions_),
+ rtol=1e-3,
+ atol=1e-3
+ )
# test transform
transp_Xs = otda.transform(Xs=Xs)
[assert_equal(x.shape, y.shape) for x, y in zip(transp_Xs, Xs)]
- Xs_new, _ = make_data_classif('3gauss', ns1 + 1)
+ Xs_new = nx.from_numpy(make_data_classif('3gauss', ns1 + 1)[0])
transp_Xs_new = otda.transform(Xs_new)
# check that the oos method is working
@@ -670,15 +668,16 @@
# check label propagation
transp_yt = otda.transform_labels(ys)
assert_equal(transp_yt.shape[0], yt.shape[0])
- assert_equal(transp_yt.shape[1], len(np.unique(ys)))
+ assert_equal(transp_yt.shape[1], len(np.unique(nx.to_numpy(*ys))))
# check inverse label propagation
transp_ys = otda.inverse_transform_labels(yt)
- [assert_equal(x.shape[0], y.shape[0]) for x, y in zip(transp_ys, ys)]
- [assert_equal(x.shape[1], len(np.unique(y))) for x, y in zip(transp_ys, ys)]
+ for x, y in zip(transp_ys, ys):
+ assert_equal(x.shape[0], y.shape[0])
+ assert_equal(x.shape[1], len(np.unique(nx.to_numpy(y))))
-def test_jcpot_barycenter():
+def test_jcpot_barycenter(nx):
"""test_jcpot_barycenter
"""
@@ -695,19 +694,23 @@
Xs1, ys1 = make_data_classif('2gauss_prop', ns1, nz=sigma, p=ps1)
Xs2, ys2 = make_data_classif('2gauss_prop', ns2, nz=sigma, p=ps2)
- Xt, yt = make_data_classif('2gauss_prop', nt, nz=sigma, p=pt)
+ Xt, _ = make_data_classif('2gauss_prop', nt, nz=sigma, p=pt)
- Xs = [Xs1, Xs2]
- ys = [ys1, ys2]
+ Xs1b, ys1b, Xs2b, ys2b, Xtb = nx.from_numpy(Xs1, ys1, Xs2, ys2, Xt)
- prop = ot.bregman.jcpot_barycenter(Xs, ys, Xt, reg=.5, metric='sqeuclidean',
+ Xsb = [Xs1b, Xs2b]
+ ysb = [ys1b, ys2b]
+
+ prop = ot.bregman.jcpot_barycenter(Xsb, ysb, Xtb, reg=.5, metric='sqeuclidean',
numItermax=10000, stopThr=1e-9, verbose=False, log=False)
- np.testing.assert_allclose(prop, [1 - pt, pt], rtol=1e-3, atol=1e-3)
+ np.testing.assert_allclose(nx.to_numpy(prop), [1 - pt, pt], rtol=1e-3, atol=1e-3)
@pytest.mark.skipif(nosklearn, reason="No sklearn available")
-def test_emd_laplace_class():
+@pytest.skip_backend("jax")
+@pytest.skip_backend("tf")
+def test_emd_laplace_class(nx):
"""test_emd_laplace_transport
"""
ns = 150
@@ -716,6 +719,8 @@
Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
+ Xs, ys, Xt, yt = nx.from_numpy(Xs, ys, Xt, yt)
+
otda = ot.da.EMDLaplaceTransport(reg_lap=0.01, max_iter=1000, tol=1e-9, verbose=False, log=True)
# test its computed
@@ -732,15 +737,15 @@
mu_t = unif(nt)
assert_allclose(
- np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
+ nx.to_numpy(nx.sum(otda.coupling_, axis=0)), mu_t, rtol=1e-3, atol=1e-3)
assert_allclose(
- np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
+ nx.to_numpy(nx.sum(otda.coupling_, axis=1)), mu_s, rtol=1e-3, atol=1e-3)
# test transform
transp_Xs = otda.transform(Xs=Xs)
[assert_equal(x.shape, y.shape) for x, y in zip(transp_Xs, Xs)]
- Xs_new, _ = make_data_classif('3gauss', ns + 1)
+ Xs_new = nx.from_numpy(make_data_classif('3gauss', ns + 1)[0])
transp_Xs_new = otda.transform(Xs_new)
# check that the oos method is working
@@ -750,7 +755,7 @@
transp_Xt = otda.inverse_transform(Xt=Xt)
assert_equal(transp_Xt.shape, Xt.shape)
- Xt_new, _ = make_data_classif('3gauss2', nt + 1)
+ Xt_new = nx.from_numpy(make_data_classif('3gauss2', nt + 1)[0])
transp_Xt_new = otda.inverse_transform(Xt=Xt_new)
# check that the oos method is working
@@ -763,9 +768,9 @@
# check label propagation
transp_yt = otda.transform_labels(ys)
assert_equal(transp_yt.shape[0], yt.shape[0])
- assert_equal(transp_yt.shape[1], len(np.unique(ys)))
+ assert_equal(transp_yt.shape[1], len(np.unique(nx.to_numpy(ys))))
# check inverse label propagation
transp_ys = otda.inverse_transform_labels(yt)
assert_equal(transp_ys.shape[0], ys.shape[0])
- assert_equal(transp_ys.shape[1], len(np.unique(yt)))
+ assert_equal(transp_ys.shape[1], len(np.unique(nx.to_numpy(yt))))
diff --git a/test/test_dr.py b/test/test_dr.py
index 741f2ad..6d7fc9a 100644
--- a/test/test_dr.py
+++ b/test/test_dr.py
@@ -61,6 +61,28 @@ def test_wda():
@pytest.mark.skipif(nogo, reason="Missing modules (autograd or pymanopt)")
+def test_wda_low_reg():
+
+ n_samples = 100 # nb samples in source and target datasets
+ np.random.seed(0)
+
+ # generate gaussian dataset
+ xs, ys = ot.datasets.make_data_classif('gaussrot', n_samples)
+
+ n_features_noise = 8
+
+ xs = np.hstack((xs, np.random.randn(n_samples, n_features_noise)))
+
+ p = 2
+
+ Pwda, projwda = ot.dr.wda(xs, ys, p, reg=0.01, maxiter=10, sinkhorn_method='sinkhorn_log')
+
+ projwda(xs)
+
+ np.testing.assert_allclose(np.sum(Pwda**2, 0), np.ones(p))
+
+
+@pytest.mark.skipif(nogo, reason="Missing modules (autograd or pymanopt)")
def test_wda_normalized():
n_samples = 100 # nb samples in source and target datasets
diff --git a/test/test_factored.py b/test/test_factored.py
new file mode 100644
index 0000000..fd2fd01
--- /dev/null
+++ b/test/test_factored.py
@@ -0,0 +1,56 @@
+"""Tests for main module ot.weak """
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+#
+# License: MIT License
+
+import ot
+import numpy as np
+
+
+def test_factored_ot():
+ # test weak ot solver and identity stationary point
+ n = 50
+ rng = np.random.RandomState(0)
+
+ xs = rng.randn(n, 2)
+ xt = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ Ga, Gb, X, log = ot.factored_optimal_transport(xs, xt, u, u, r=10, log=True)
+
+ # check constraints
+ np.testing.assert_allclose(u, Ga.sum(1))
+ np.testing.assert_allclose(u, Gb.sum(0))
+
+ Ga, Gb, X, log = ot.factored_optimal_transport(xs, xt, u, u, reg=1, r=10, log=True)
+
+ # check constraints
+ np.testing.assert_allclose(u, Ga.sum(1))
+ np.testing.assert_allclose(u, Gb.sum(0))
+
+
+def test_factored_ot_backends(nx):
+ # test weak ot solver for different backends
+ n = 50
+ rng = np.random.RandomState(0)
+
+ xs = rng.randn(n, 2)
+ xt = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ xs2 = nx.from_numpy(xs)
+ xt2 = nx.from_numpy(xt)
+ u2 = nx.from_numpy(u)
+
+ Ga2, Gb2, X2 = ot.factored_optimal_transport(xs2, xt2, u2, u2, r=10)
+
+ # check constraints
+ np.testing.assert_allclose(u, nx.to_numpy(Ga2).sum(1))
+ np.testing.assert_allclose(u, nx.to_numpy(Gb2).sum(0))
+
+ Ga2, Gb2, X2 = ot.factored_optimal_transport(xs2, xt2, reg=1, r=10, X0=X2)
+
+ # check constraints
+ np.testing.assert_allclose(u, nx.to_numpy(Ga2).sum(1))
+ np.testing.assert_allclose(u, nx.to_numpy(Gb2).sum(0))
diff --git a/test/test_gpu.py b/test/test_gpu.py
deleted file mode 100644
index 8e62a74..0000000
--- a/test/test_gpu.py
+++ /dev/null
@@ -1,106 +0,0 @@
-"""Tests for module gpu for gpu acceleration """
-
-# Author: Remi Flamary <remi.flamary@unice.fr>
-#
-# License: MIT License
-
-import numpy as np
-import ot
-import pytest
-
-try: # test if cudamat installed
- import ot.gpu
- nogpu = False
-except ImportError:
- nogpu = True
-
-
-@pytest.mark.skipif(nogpu, reason="No GPU available")
-def test_gpu_old_doctests():
- a = [.5, .5]
- b = [.5, .5]
- M = [[0., 1.], [1., 0.]]
- G = ot.sinkhorn(a, b, M, 1)
- np.testing.assert_allclose(G, np.array([[0.36552929, 0.13447071],
- [0.13447071, 0.36552929]]))
-
-
-@pytest.mark.skipif(nogpu, reason="No GPU available")
-def test_gpu_dist():
-
- rng = np.random.RandomState(0)
-
- for n_samples in [50, 100, 500, 1000]:
- print(n_samples)
- a = rng.rand(n_samples // 4, 100)
- b = rng.rand(n_samples, 100)
-
- M = ot.dist(a.copy(), b.copy())
- M2 = ot.gpu.dist(a.copy(), b.copy())
-
- np.testing.assert_allclose(M, M2, rtol=1e-10)
-
- M2 = ot.gpu.dist(a.copy(), b.copy(), metric='euclidean', to_numpy=False)
-
- # check raise not implemented wrong metric
- with pytest.raises(NotImplementedError):
- M2 = ot.gpu.dist(a.copy(), b.copy(), metric='cityblock', to_numpy=False)
-
-
-@pytest.mark.skipif(nogpu, reason="No GPU available")
-def test_gpu_sinkhorn():
-
- rng = np.random.RandomState(0)
-
- for n_samples in [50, 100, 500, 1000]:
- a = rng.rand(n_samples // 4, 100)
- b = rng.rand(n_samples, 100)
-
- wa = ot.unif(n_samples // 4)
- wb = ot.unif(n_samples)
-
- wb2 = np.random.rand(n_samples, 20)
- wb2 /= wb2.sum(0, keepdims=True)
-
- M = ot.dist(a.copy(), b.copy())
- M2 = ot.gpu.dist(a.copy(), b.copy(), to_numpy=False)
-
- reg = 1
-
- G = ot.sinkhorn(wa, wb, M, reg)
- G1 = ot.gpu.sinkhorn(wa, wb, M, reg)
-
- np.testing.assert_allclose(G1, G, rtol=1e-10)
-
- # run all on gpu
- ot.gpu.sinkhorn(wa, wb, M2, reg, to_numpy=False, log=True)
-
- # run sinkhorn for multiple targets
- ot.gpu.sinkhorn(wa, wb2, M2, reg, to_numpy=False, log=True)
-
-
-@pytest.mark.skipif(nogpu, reason="No GPU available")
-def test_gpu_sinkhorn_lpl1():
-
- rng = np.random.RandomState(0)
-
- for n_samples in [50, 100, 500]:
- print(n_samples)
- a = rng.rand(n_samples // 4, 100)
- labels_a = np.random.randint(10, size=(n_samples // 4))
- b = rng.rand(n_samples, 100)
-
- wa = ot.unif(n_samples // 4)
- wb = ot.unif(n_samples)
-
- M = ot.dist(a.copy(), b.copy())
- M2 = ot.gpu.dist(a.copy(), b.copy(), to_numpy=False)
-
- reg = 1
-
- G = ot.da.sinkhorn_lpl1_mm(wa, labels_a, wb, M, reg)
- G1 = ot.gpu.da.sinkhorn_lpl1_mm(wa, labels_a, wb, M, reg)
-
- np.testing.assert_allclose(G1, G, rtol=1e-10)
-
- ot.gpu.da.sinkhorn_lpl1_mm(wa, labels_a, wb, M2, reg, to_numpy=False, log=True)
diff --git a/test/test_gromov.py b/test/test_gromov.py
index 4b995d5..9c85b92 100644
--- a/test/test_gromov.py
+++ b/test/test_gromov.py
@@ -3,6 +3,7 @@
# Author: Erwan Vautier <erwan.vautier@gmail.com>
# Nicolas Courty <ncourty@irisa.fr>
# Titouan Vayer <titouan.vayer@irisa.fr>
+# CĂ©dric Vincent-Cuaz <cedric.vincent-cuaz@inria.fr>
#
# License: MIT License
@@ -26,6 +27,7 @@ def test_gromov(nx):
p = ot.unif(n_samples)
q = ot.unif(n_samples)
+ G0 = p[:, None] * q[None, :]
C1 = ot.dist(xs, xs)
C2 = ot.dist(xt, xt)
@@ -33,13 +35,10 @@
C1 /= C1.max()
C2 /= C2.max()
- C1b = nx.from_numpy(C1)
- C2b = nx.from_numpy(C2)
- pb = nx.from_numpy(p)
- qb = nx.from_numpy(q)
+ C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0)
- G = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss', verbose=True)
- Gb = nx.to_numpy(ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', verbose=True))
+ G = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss', G0=G0, verbose=True)
+ Gb = nx.to_numpy(ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', G0=G0b, verbose=True))
# check constraints
np.testing.assert_allclose(G, Gb, atol=1e-06)
@@ -56,9 +55,9 @@
gwb, logb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=True)
gwb = nx.to_numpy(gwb)
- gw_val = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', log=False)
+ gw_val = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', G0=G0, log=False)
gw_valb = nx.to_numpy(
- ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=False)
+ ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', G0=G0b, log=False)
)
G = log['T']
@@ -91,6 +90,7 @@ def test_gromov_dtype_device(nx):
p = ot.unif(n_samples)
q = ot.unif(n_samples)
+ G0 = p[:, None] * q[None, :]
C1 = ot.dist(xs, xs)
C2 = ot.dist(xt, xt)
@@ -101,13 +101,10 @@
for tp in nx.__type_list__:
print(nx.dtype_device(tp))
- C1b = nx.from_numpy(C1, type_as=tp)
- C2b = nx.from_numpy(C2, type_as=tp)
- pb = nx.from_numpy(p, type_as=tp)
- qb = nx.from_numpy(q, type_as=tp)
+ C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0, type_as=tp)
- Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', verbose=True)
- gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=False)
+ Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', G0=G0b, verbose=True)
+ gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', G0=G0b, log=False)
nx.assert_same_dtype_device(C1b, Gb)
nx.assert_same_dtype_device(C1b, gw_valb)
@@ -123,6 +120,7 @@ def test_gromov_device_tf():
xt = xs[::-1].copy()
p = ot.unif(n_samples)
q = ot.unif(n_samples)
+ G0 = p[:, None] * q[None, :]
C1 = ot.dist(xs, xs)
C2 = ot.dist(xt, xt)
C1 /= C1.max()
@@ -130,21 +128,15 @@
# Check that everything stays on the CPU
with tf.device("/CPU:0"):
- C1b = nx.from_numpy(C1)
- C2b = nx.from_numpy(C2)
- pb = nx.from_numpy(p)
- qb = nx.from_numpy(q)
- Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', verbose=True)
- gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=False)
+ C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0)
+ Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', G0=G0b, verbose=True)
+ gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', G0=G0b, log=False)
nx.assert_same_dtype_device(C1b, Gb)
nx.assert_same_dtype_device(C1b, gw_valb)
if len(tf.config.list_physical_devices('GPU')) > 0:
# Check that everything happens on the GPU
- C1b = nx.from_numpy(C1)
- C2b = nx.from_numpy(C2)
- pb = nx.from_numpy(p)
- qb = nx.from_numpy(q)
+ C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0)
Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', verbose=True)
gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=False)
nx.assert_same_dtype_device(C1b, Gb)
@@ -173,25 +165,30 @@ def test_gromov2_gradients():
if torch:
- p1 = torch.tensor(p, requires_grad=True)
- q1 = torch.tensor(q, requires_grad=True)
- C11 = torch.tensor(C1, requires_grad=True)
- C12 = torch.tensor(C2, requires_grad=True)
+ devices = [torch.device("cpu")]
+ if torch.cuda.is_available():
+ devices.append(torch.device("cuda"))
+ for device in devices:
+ p1 = torch.tensor(p, requires_grad=True, device=device)
+ q1 = torch.tensor(q, requires_grad=True, device=device)
+ C11 = torch.tensor(C1, requires_grad=True, device=device)
+ C12 = torch.tensor(C2, requires_grad=True, device=device)
- val = ot.gromov_wasserstein2(C11, C12, p1, q1)
+ val = ot.gromov_wasserstein2(C11, C12, p1, q1)
- val.backward()
+ val.backward()
- assert q1.shape == q1.grad.shape
- assert p1.shape == p1.grad.shape
- assert C11.shape == C11.grad.shape
- assert C12.shape == C12.grad.shape
+ assert val.device == p1.device
+ assert q1.shape == q1.grad.shape
+ assert p1.shape == p1.grad.shape
+ assert C11.shape == C11.grad.shape
+ assert C12.shape == C12.grad.shape
@pytest.skip_backend("jax", reason="test very slow with jax backend")
@pytest.skip_backend("tf", reason="test very slow with tf backend")
def test_entropic_gromov(nx):
- n_samples = 50 # nb samples
+ n_samples = 10 # nb samples
mu_s = np.array([0, 0])
cov_s = np.array([[1, 0], [0, 1]])
@@ -209,10 +206,7 @@
C1 /= C1.max()
C2 /= C2.max()
- C1b = nx.from_numpy(C1)
- C2b = nx.from_numpy(C2)
- pb = nx.from_numpy(p)
- qb = nx.from_numpy(q)
+ C1b, C2b, pb, qb = nx.from_numpy(C1, C2, p, q)
G = ot.gromov.entropic_gromov_wasserstein(
C1, C2, p, q, 'square_loss', epsilon=5e-4, verbose=True)
@@ -228,9 +222,9 @@
q, Gb.sum(0), atol=1e-04) # cf convergence gromov
gw, log = ot.gromov.entropic_gromov_wasserstein2(
- C1, C2, p, q, 'kl_loss', epsilon=1e-2, log=True)
+ C1, C2, p, q, 'kl_loss', max_iter=10, epsilon=1e-2, log=True)
gwb, logb = ot.gromov.entropic_gromov_wasserstein2(
- C1b, C2b, pb, qb, 'kl_loss', epsilon=1e-2, log=True)
+ C1b, C2b, pb, qb, 'kl_loss', max_iter=10, epsilon=1e-2, log=True)
gwb = nx.to_numpy(gwb)
G = log['T']
@@ -251,7 +245,7 @@
@pytest.skip_backend("tf", reason="test very slow with tf backend")
def test_entropic_gromov_dtype_device(nx):
# setup
- n_samples = 50 # nb samples
+ n_samples = 5 # nb samples
mu_s = np.array([0, 0])
cov_s = np.array([[1, 0], [0, 1]])
@@ -272,10 +266,7 @@
for tp in nx.__type_list__:
print(nx.dtype_device(tp))
- C1b = nx.from_numpy(C1, type_as=tp)
- C2b = nx.from_numpy(C2, type_as=tp)
- pb = nx.from_numpy(p, type_as=tp)
- qb = nx.from_numpy(q, type_as=tp)
+ C1b, C2b, pb, qb = nx.from_numpy(C1, C2, p, q, type_as=tp)
Gb = ot.gromov.entropic_gromov_wasserstein(
C1b, C2b, pb, qb, 'square_loss', epsilon=5e-4, verbose=True
@@ -289,7 +280,7 @@
def test_pointwise_gromov(nx):
- n_samples = 50 # nb samples
+ n_samples = 5 # nb samples
mu_s = np.array([0, 0])
cov_s = np.array([[1, 0], [0, 1]])
@@ -307,10 +298,7 @@
C1 /= C1.max()
C2 /= C2.max()
- C1b = nx.from_numpy(C1)
- C2b = nx.from_numpy(C2)
- pb = nx.from_numpy(p)
- qb = nx.from_numpy(q)
+ C1b, C2b, pb, qb = nx.from_numpy(C1, C2, p, q)
def loss(x, y):
return np.abs(x - y)
@@ -343,14 +331,12 @@
Gb = nx.to_numpy(nx.todense(Gb))
np.testing.assert_allclose(G, Gb, atol=1e-06)
- np.testing.assert_allclose(float(logb['gw_dist_estimated']), 0.10342276348494964, atol=1e-8)
- np.testing.assert_allclose(float(logb['gw_dist_std']), 0.0015952535464736394, atol=1e-8)
@pytest.skip_backend("tf", reason="test very slow with tf backend")
@pytest.skip_backend("jax", reason="test very slow with jax backend")
def test_sampled_gromov(nx):
- n_samples = 50 # nb samples
+ n_samples = 5 # nb samples
mu_s = np.array([0, 0], dtype=np.float64)
cov_s = np.array([[1, 0], [0, 1]], dtype=np.float64)
@@ -368,10 +354,7 @@
C1 /= C1.max()
C2 /= C2.max()
- C1b = nx.from_numpy(C1)
- C2b = nx.from_numpy(C2)
- pb = nx.from_numpy(p)
- qb = nx.from_numpy(q)
+ C1b, C2b, pb, qb = nx.from_numpy(C1, C2, p, q)
def loss(x, y):
return np.abs(x - y)
@@ -380,9 +363,9 @@
return nx.abs(x - y)
G, log = ot.gromov.sampled_gromov_wasserstein(
- C1, C2, p, q, loss, max_iter=100, epsilon=1, log=True, verbose=True, random_state=42)
+ C1, C2, p, q, loss, max_iter=20, nb_samples_grad=2, epsilon=1, log=True, verbose=True, random_state=42)
Gb, logb = ot.gromov.sampled_gromov_wasserstein(
- C1b, C2b, pb, qb, lossb, max_iter=100, epsilon=1, log=True, verbose=True, random_state=42)
+ C1b, C2b, pb, qb, lossb, max_iter=20, nb_samples_grad=2, epsilon=1, log=True, verbose=True, random_state=42)
Gb = nx.to_numpy(Gb)
# check constraints
@@ -392,13 +375,10 @@
np.testing.assert_allclose(
q, Gb.sum(0), atol=1e-04) # cf convergence gromov
- np.testing.assert_allclose(float(logb['gw_dist_estimated']), 0.05679474884977278, atol=1e-08)
- np.testing.assert_allclose(float(logb['gw_dist_std']), 0.0005986592106971995, atol=1e-08)
-
def test_gromov_barycenter(nx):
- ns = 10
- nt = 20
+ ns = 5
+ nt = 8
Xs, ys = ot.datasets.make_data_classif('3gauss', ns, random_state=42)
Xt, yt = ot.datasets.make_data_classif('3gauss2', nt, random_state=42)
@@ -410,19 +390,15 @@
n_samples = 3
p = ot.unif(n_samples)
- C1b = nx.from_numpy(C1)
- C2b = nx.from_numpy(C2)
- p1b = nx.from_numpy(p1)
- p2b = nx.from_numpy(p2)
- pb = nx.from_numpy(p)
+ C1b, C2b, p1b, p2b, pb = nx.from_numpy(C1, C2, p1, p2, p)
Cb = ot.gromov.gromov_barycenters(
n_samples, [C1, C2], [p1, p2], p, [.5, .5],
- 'square_loss', max_iter=100, tol=1e-3, verbose=True, random_state=42
+ 'square_loss', max_iter=100, tol=1e-3, verbose=False, random_state=42
)
Cbb = nx.to_numpy(ot.gromov.gromov_barycenters(
n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5],
- 'square_loss', max_iter=100, tol=1e-3, verbose=True, random_state=42
+ 'square_loss', max_iter=100, tol=1e-3, verbose=False, random_state=42
))
np.testing.assert_allclose(Cb, Cbb, atol=1e-06)
np.testing.assert_allclose(Cbb.shape, (n_samples, n_samples))
@@ -430,15 +406,15 @@
# test of gromov_barycenters with `log` on
Cb_, err_ = ot.gromov.gromov_barycenters(
n_samples, [C1, C2], [p1, p2], p, [.5, .5],
- 'square_loss', max_iter=100, tol=1e-3, verbose=True, random_state=42, log=True
+ 'square_loss', max_iter=100, tol=1e-3, verbose=False, random_state=42, log=True
)
Cbb_, errb_ = ot.gromov.gromov_barycenters(
n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5],
- 'square_loss', max_iter=100, tol=1e-3, verbose=True, random_state=42, log=True
+ 'square_loss', max_iter=100, tol=1e-3, verbose=False, random_state=42, log=True
)
Cbb_ = nx.to_numpy(Cbb_)
np.testing.assert_allclose(Cb_, Cbb_, atol=1e-06)
- np.testing.assert_array_almost_equal(err_['err'], errb_['err'])
+ np.testing.assert_array_almost_equal(err_['err'], nx.to_numpy(*errb_['err']))
np.testing.assert_allclose(Cbb_.shape, (n_samples, n_samples))
Cb2 = ot.gromov.gromov_barycenters(
@@ -455,22 +431,22 @@
# test of gromov_barycenters with `log` on
Cb2_, err2_ = ot.gromov.gromov_barycenters(
n_samples, [C1, C2], [p1, p2], p, [.5, .5],
- 'kl_loss', max_iter=100, tol=1e-3, verbose=True, random_state=42, log=True
+ 'kl_loss', max_iter=100, tol=1e-3, verbose=False, random_state=42, log=True
)
Cb2b_, err2b_ = ot.gromov.gromov_barycenters(
n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5],
- 'kl_loss', max_iter=100, tol=1e-3, verbose=True, random_state=42, log=True
+ 'kl_loss', max_iter=100, tol=1e-3, verbose=False, random_state=42, log=True
)
Cb2b_ = nx.to_numpy(Cb2b_)
np.testing.assert_allclose(Cb2_, Cb2b_, atol=1e-06)
- np.testing.assert_array_almost_equal(err2_['err'], err2_['err'])
+ np.testing.assert_array_almost_equal(err2_['err'], nx.to_numpy(*err2b_['err']))
np.testing.assert_allclose(Cb2b_.shape, (n_samples, n_samples))
@pytest.mark.filterwarnings("ignore:divide")
def test_gromov_entropic_barycenter(nx):
- ns = 10
- nt = 20
+ ns = 5
+ nt = 10
Xs, ys = ot.datasets.make_data_classif('3gauss', ns, random_state=42)
Xt, yt = ot.datasets.make_data_classif('3gauss2', nt, random_state=42)
@@ -482,11 +458,7 @@ def test_gromov_entropic_barycenter(nx):
n_samples = 2
p = ot.unif(n_samples)
- C1b = nx.from_numpy(C1)
- C2b = nx.from_numpy(C2)
- p1b = nx.from_numpy(p1)
- p2b = nx.from_numpy(p2)
- pb = nx.from_numpy(p)
+ C1b, C2b, p1b, p2b, pb = nx.from_numpy(C1, C2, p1, p2, p)
Cb = ot.gromov.entropic_gromov_barycenters(
n_samples, [C1, C2], [p1, p2], p, [.5, .5],
@@ -510,7 +482,7 @@ def test_gromov_entropic_barycenter(nx):
)
Cbb_ = nx.to_numpy(Cbb_)
np.testing.assert_allclose(Cb_, Cbb_, atol=1e-06)
- np.testing.assert_array_almost_equal(err_['err'], errb_['err'])
+ np.testing.assert_array_almost_equal(err_['err'], nx.to_numpy(*errb_['err']))
np.testing.assert_allclose(Cbb_.shape, (n_samples, n_samples))
Cb2 = ot.gromov.entropic_gromov_barycenters(
@@ -535,12 +507,12 @@ def test_gromov_entropic_barycenter(nx):
)
Cb2b_ = nx.to_numpy(Cb2b_)
np.testing.assert_allclose(Cb2_, Cb2b_, atol=1e-06)
- np.testing.assert_array_almost_equal(err2_['err'], err2_['err'])
+ np.testing.assert_array_almost_equal(err2_['err'], nx.to_numpy(*err2b_['err']))
np.testing.assert_allclose(Cb2b_.shape, (n_samples, n_samples))
def test_fgw(nx):
- n_samples = 50 # nb samples
+ n_samples = 20 # nb samples
mu_s = np.array([0, 0])
cov_s = np.array([[1, 0], [0, 1]])
@@ -554,6 +526,7 @@ def test_fgw(nx):
p = ot.unif(n_samples)
q = ot.unif(n_samples)
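+ # G0 is the product coupling p q^T, used below as an explicit initialization of the solver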
+ G0 = p[:, None] * q[None, :]
C1 = ot.dist(xs, xs)
C2 = ot.dist(xt, xt)
@@ -564,14 +537,10 @@ def test_fgw(nx):
M = ot.dist(ys, yt)
M /= M.max()
- Mb = nx.from_numpy(M)
- C1b = nx.from_numpy(C1)
- C2b = nx.from_numpy(C2)
- pb = nx.from_numpy(p)
- qb = nx.from_numpy(q)
+ Mb, C1b, C2b, pb, qb, G0b = nx.from_numpy(M, C1, C2, p, q, G0)
- G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5, log=True)
- Gb, logb = ot.gromov.fused_gromov_wasserstein(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, log=True)
+ G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5, G0=G0, log=True)
+ Gb, logb = ot.gromov.fused_gromov_wasserstein(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, G0=G0b, log=True)
Gb = nx.to_numpy(Gb)
# check constraints
@@ -586,8 +555,8 @@ def test_fgw(nx):
np.testing.assert_allclose(
Gb, np.flipud(Id), atol=1e-04) # cf convergence gromov
- fgw, log = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, 'square_loss', alpha=0.5, log=True)
- fgwb, logb = ot.gromov.fused_gromov_wasserstein2(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, log=True)
+ fgw, log = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, 'square_loss', G0=None, alpha=0.5, log=True)
+ fgwb, logb = ot.gromov.fused_gromov_wasserstein2(Mb, C1b, C2b, pb, qb, 'square_loss', G0=G0b, alpha=0.5, log=True)
fgwb = nx.to_numpy(fgwb)
G = log['T']
@@ -605,7 +574,7 @@ def test_fgw(nx):
def test_fgw2_gradients():
- n_samples = 50 # nb samples
+ n_samples = 20 # nb samples
mu_s = np.array([0, 0])
cov_s = np.array([[1, 0], [0, 1]])
@@ -626,28 +595,33 @@ def test_fgw2_gradients():
if torch:
- p1 = torch.tensor(p, requires_grad=True)
- q1 = torch.tensor(q, requires_grad=True)
- C11 = torch.tensor(C1, requires_grad=True)
- C12 = torch.tensor(C2, requires_grad=True)
- M1 = torch.tensor(M, requires_grad=True)
+ devices = [torch.device("cpu")]
+ if torch.cuda.is_available():
+ devices.append(torch.device("cuda"))
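+ # run the same gradient checks on CPU and, when available, on GPU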
+ for device in devices:
+ p1 = torch.tensor(p, requires_grad=True, device=device)
+ q1 = torch.tensor(q, requires_grad=True, device=device)
+ C11 = torch.tensor(C1, requires_grad=True, device=device)
+ C12 = torch.tensor(C2, requires_grad=True, device=device)
+ M1 = torch.tensor(M, requires_grad=True, device=device)
- val = ot.fused_gromov_wasserstein2(M1, C11, C12, p1, q1)
+ val = ot.fused_gromov_wasserstein2(M1, C11, C12, p1, q1)
- val.backward()
+ val.backward()
- assert q1.shape == q1.grad.shape
- assert p1.shape == p1.grad.shape
- assert C11.shape == C11.grad.shape
- assert C12.shape == C12.grad.shape
- assert M1.shape == M1.grad.shape
+ assert val.device == p1.device
+ assert q1.shape == q1.grad.shape
+ assert p1.shape == p1.grad.shape
+ assert C11.shape == C11.grad.shape
+ assert C12.shape == C12.grad.shape
+ assert M1.shape == M1.grad.shape
def test_fgw_barycenter(nx):
np.random.seed(42)
- ns = 50
- nt = 60
+ ns = 10
+ nt = 20
Xs, ys = ot.datasets.make_data_classif('3gauss', ns, random_state=42)
Xt, yt = ot.datasets.make_data_classif('3gauss2', nt, random_state=42)
@@ -661,13 +635,7 @@ def test_fgw_barycenter(nx):
n_samples = 3
p = ot.unif(n_samples)
- ysb = nx.from_numpy(ys)
- ytb = nx.from_numpy(yt)
- C1b = nx.from_numpy(C1)
- C2b = nx.from_numpy(C2)
- p1b = nx.from_numpy(p1)
- p2b = nx.from_numpy(p2)
- pb = nx.from_numpy(p)
+ ysb, ytb, C1b, C2b, p1b, p2b, pb = nx.from_numpy(ys, yt, C1, C2, p1, p2, p)
Xb, Cb = ot.gromov.fgw_barycenters(
n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], [.5, .5], 0.5, fixed_structure=False,
@@ -698,3 +666,523 @@ def test_fgw_barycenter(nx):
Xb, Cb = nx.to_numpy(Xb), nx.to_numpy(Cb)
np.testing.assert_allclose(Cb.shape, (n_samples, n_samples))
np.testing.assert_allclose(Xb.shape, (n_samples, ys.shape[1]))
+
+
+def test_gromov_wasserstein_linear_unmixing(nx):
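+ # unmix each structure over a dictionary made of the two structures themselves,
+ # so each unmixing is expected to concentrate on its matching atom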
+ n = 4
+
+ X1, y1 = ot.datasets.make_data_classif('3gauss', n, random_state=42)
+ X2, y2 = ot.datasets.make_data_classif('3gauss2', n, random_state=42)
+
+ C1 = ot.dist(X1)
+ C2 = ot.dist(X2)
+ Cdict = np.stack([C1, C2])
+ p = ot.unif(n)
+
+ C1b, C2b, Cdictb, pb = nx.from_numpy(C1, C2, Cdict, p)
+
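+ # calls below alternate explicit weights with p=None / q=None to also cover the default uniform-weights path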
+ tol = 10**(-5)
+ # Tests without regularization
+ reg = 0.
+ unmixing1, C1_emb, OT, reconstruction1 = ot.gromov.gromov_wasserstein_linear_unmixing(
+ C1, Cdict, reg=reg, p=p, q=p,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200
+ )
+
+ unmixing1b, C1b_emb, OTb, reconstruction1b = ot.gromov.gromov_wasserstein_linear_unmixing(
+ C1b, Cdictb, reg=reg, p=None, q=None,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200
+ )
+
+ unmixing2, C2_emb, OT, reconstruction2 = ot.gromov.gromov_wasserstein_linear_unmixing(
+ C2, Cdict, reg=reg, p=None, q=None,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200
+ )
+
+ unmixing2b, C2b_emb, OTb, reconstruction2b = ot.gromov.gromov_wasserstein_linear_unmixing(
+ C2b, Cdictb, reg=reg, p=pb, q=pb,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200
+ )
+
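+ # numpy and backend results should match, and each unmixing should select its own atom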
+ np.testing.assert_allclose(unmixing1, nx.to_numpy(unmixing1b), atol=5e-06)
+ np.testing.assert_allclose(unmixing1, [1., 0.], atol=5e-01)
+ np.testing.assert_allclose(unmixing2, nx.to_numpy(unmixing2b), atol=5e-06)
+ np.testing.assert_allclose(unmixing2, [0., 1.], atol=5e-01)
+ np.testing.assert_allclose(C1_emb, nx.to_numpy(C1b_emb), atol=1e-06)
+ np.testing.assert_allclose(C2_emb, nx.to_numpy(C2b_emb), atol=1e-06)
+ np.testing.assert_allclose(reconstruction1, nx.to_numpy(reconstruction1b), atol=1e-06)
+ np.testing.assert_allclose(reconstruction2, nx.to_numpy(reconstruction2b), atol=1e-06)
+ np.testing.assert_allclose(C1b_emb.shape, (n, n))
+ np.testing.assert_allclose(C2b_emb.shape, (n, n))
+
+ # Tests with regularization
+
+ reg = 0.001
+ unmixing1, C1_emb, OT, reconstruction1 = ot.gromov.gromov_wasserstein_linear_unmixing(
+ C1, Cdict, reg=reg, p=p, q=p,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200
+ )
+
+ unmixing1b, C1b_emb, OTb, reconstruction1b = ot.gromov.gromov_wasserstein_linear_unmixing(
+ C1b, Cdictb, reg=reg, p=None, q=None,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200
+ )
+
+ unmixing2, C2_emb, OT, reconstruction2 = ot.gromov.gromov_wasserstein_linear_unmixing(
+ C2, Cdict, reg=reg, p=None, q=None,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200
+ )
+
+ unmixing2b, C2b_emb, OTb, reconstruction2b = ot.gromov.gromov_wasserstein_linear_unmixing(
+ C2b, Cdictb, reg=reg, p=pb, q=pb,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200
+ )
+
+ np.testing.assert_allclose(unmixing1, nx.to_numpy(unmixing1b), atol=1e-06)
+ np.testing.assert_allclose(unmixing1, [1., 0.], atol=1e-01)
+ np.testing.assert_allclose(unmixing2, nx.to_numpy(unmixing2b), atol=1e-06)
+ np.testing.assert_allclose(unmixing2, [0., 1.], atol=1e-01)
+ np.testing.assert_allclose(C1_emb, nx.to_numpy(C1b_emb), atol=1e-06)
+ np.testing.assert_allclose(C2_emb, nx.to_numpy(C2b_emb), atol=1e-06)
+ np.testing.assert_allclose(reconstruction1, nx.to_numpy(reconstruction1b), atol=1e-06)
+ np.testing.assert_allclose(reconstruction2, nx.to_numpy(reconstruction2b), atol=1e-06)
+ np.testing.assert_allclose(C1b_emb.shape, (n, n))
+ np.testing.assert_allclose(C2b_emb.shape, (n, n))
+
+
+def test_gromov_wasserstein_dictionary_learning(nx):
+
+ # create a dataset composed of 2 structures, each repeated n_samples // 2 times
+ shape = 4
+ n_samples = 2
+ n_atoms = 2
+ projection = 'nonnegative_symmetric'
+ X1, y1 = ot.datasets.make_data_classif('3gauss', shape, random_state=42)
+ X2, y2 = ot.datasets.make_data_classif('3gauss2', shape, random_state=42)
+ C1 = ot.dist(X1)
+ C2 = ot.dist(X2)
+ Cs = [C1.copy() for _ in range(n_samples // 2)] + [C2.copy() for _ in range(n_samples // 2)]
+ ps = [ot.unif(shape) for _ in range(n_samples)]
+ q = ot.unif(shape)
+
+ # Provide initialization for the graph dictionary of shape (n_atoms, shape, shape)
+ # following the same procedure as implemented in gromov_wasserstein_dictionary_learning.
+ dataset_means = [C.mean() for C in Cs]
+ np.random.seed(0)
+ Cdict_init = np.random.normal(loc=np.mean(dataset_means), scale=np.std(dataset_means), size=(n_atoms, shape, shape))
+
+ if projection == 'nonnegative_symmetric':
+ Cdict_init = 0.5 * (Cdict_init + Cdict_init.transpose((0, 2, 1)))
+ Cdict_init[Cdict_init < 0.] = 0.
+
+ Csb = nx.from_numpy(*Cs)
+ psb = nx.from_numpy(*ps)
+ qb, Cdict_initb = nx.from_numpy(q, Cdict_init)
+
+ # Test: compare the reconstruction error of the initial dictionary with that of the dictionary learned from it
+ # > Compute initial reconstruction of samples on this random dictionary without backend
+ use_adam_optimizer = True
+ verbose = False
+ tol = 10**(-5)
+ epochs = 1
+
+ initial_total_reconstruction = 0
+ for i in range(n_samples):
+ _, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing(
+ Cs[i], Cdict_init, p=ps[i], q=q, reg=0.,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50
+ )
+ initial_total_reconstruction += reconstruction
+
+ # > Learn the dictionary using this init
+ Cdict, log = ot.gromov.gromov_wasserstein_dictionary_learning(
+ Cs, D=n_atoms, nt=shape, ps=ps, q=q, Cdict_init=Cdict_init,
+ epochs=epochs, batch_size=2 * n_samples, learning_rate=1., reg=0.,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50,
+ projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose
+ )
+ # > Compute reconstruction of samples on learned dictionary without backend
+ total_reconstruction = 0
+ for i in range(n_samples):
+ _, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing(
+ Cs[i], Cdict, p=None, q=None, reg=0.,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50
+ )
+ total_reconstruction += reconstruction
+
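+ # learning should improve the reconstruction compared to the random initialization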
+ np.testing.assert_array_less(total_reconstruction, initial_total_reconstruction)
+
+ # Test: Perform the same experiments after going through the backend
+
+ Cdictb, log = ot.gromov.gromov_wasserstein_dictionary_learning(
+ Csb, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=Cdict_initb,
+ epochs=epochs, batch_size=n_samples, learning_rate=1., reg=0.,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50,
+ projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose
+ )
+ # Compute reconstruction of samples on learned dictionary
+ total_reconstruction_b = 0
+ for i in range(n_samples):
+ _, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing(
+ Csb[i], Cdictb, p=psb[i], q=qb, reg=0.,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50
+ )
+ total_reconstruction_b += reconstruction
+
+ total_reconstruction_b = nx.to_numpy(total_reconstruction_b)
+ np.testing.assert_array_less(total_reconstruction_b, initial_total_reconstruction)
+ np.testing.assert_allclose(total_reconstruction_b, total_reconstruction, atol=1e-05)
+ np.testing.assert_allclose(Cdict, nx.to_numpy(Cdictb), atol=1e-03)
+
+ # Test: Perform the same comparison without providing the initial dictionary (an optional input),
+ # knowing that the initialization scheme is the same as the one used for the benchmarked initialization.
+ np.random.seed(0)
+ Cdict_bis, log = ot.gromov.gromov_wasserstein_dictionary_learning(
+ Cs, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=None,
+ epochs=epochs, batch_size=n_samples, learning_rate=1., reg=0.,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50,
+ projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose
+ )
+ # > Compute reconstruction of samples on learned dictionary
+ total_reconstruction_bis = 0
+ for i in range(n_samples):
+ _, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing(
+ Cs[i], Cdict_bis, p=ps[i], q=q, reg=0.,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50
+ )
+ total_reconstruction_bis += reconstruction
+
+ np.testing.assert_allclose(total_reconstruction_bis, total_reconstruction, atol=1e-05)
+
+ # Test: Same after going through backend
+ np.random.seed(0)
+ Cdictb_bis, log = ot.gromov.gromov_wasserstein_dictionary_learning(
+ Csb, D=n_atoms, nt=shape, ps=psb, q=qb, Cdict_init=None,
+ epochs=epochs, batch_size=n_samples, learning_rate=1., reg=0.,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50,
+ projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose
+ )
+ # > Compute reconstruction of samples on learned dictionary
+ total_reconstruction_b_bis = 0
+ for i in range(n_samples):
+ _, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing(
+ Csb[i], Cdictb_bis, p=None, q=None, reg=0.,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50
+ )
+ total_reconstruction_b_bis += reconstruction
+
+ total_reconstruction_b_bis = nx.to_numpy(total_reconstruction_b_bis)
+ np.testing.assert_allclose(total_reconstruction_b_bis, total_reconstruction_b, atol=1e-05)
+ np.testing.assert_allclose(Cdict_bis, nx.to_numpy(Cdictb_bis), atol=1e-03)
+
+ # Test: Perform the same comparison without providing the initial dictionary (an optional input),
+ # this time exercising optimization settings not covered above.
+ # We pass previously estimated dictionaries to speed up the process.
+ use_adam_optimizer = False
+ verbose = True
+ use_log = True
+
+ np.random.seed(0)
+ Cdict_bis2, log = ot.gromov.gromov_wasserstein_dictionary_learning(
+ Cs, D=n_atoms, nt=shape, ps=ps, q=q, Cdict_init=Cdict,
+ epochs=epochs, batch_size=n_samples, learning_rate=10., reg=0.,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50,
+ projection=projection, use_log=use_log, use_adam_optimizer=use_adam_optimizer, verbose=verbose
+ )
+ # > Compute reconstruction of samples on learned dictionary
+ total_reconstruction_bis2 = 0
+ for i in range(n_samples):
+ _, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing(
+ Cs[i], Cdict_bis2, p=ps[i], q=q, reg=0.,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50
+ )
+ total_reconstruction_bis2 += reconstruction
+
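+ # the extra pass with plain SGD and a larger learning rate is expected to further reduce the error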
+ np.testing.assert_array_less(total_reconstruction_bis2, total_reconstruction)
+
+ # Test: Same after going through backend
+ np.random.seed(0)
+ Cdictb_bis2, log = ot.gromov.gromov_wasserstein_dictionary_learning(
+ Csb, D=n_atoms, nt=shape, ps=psb, q=qb, Cdict_init=Cdictb,
+ epochs=epochs, batch_size=n_samples, learning_rate=10., reg=0.,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50,
+ projection=projection, use_log=use_log, use_adam_optimizer=use_adam_optimizer, verbose=verbose
+ )
+ # > Compute reconstruction of samples on learned dictionary
+ total_reconstruction_b_bis2 = 0
+ for i in range(n_samples):
+ _, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing(
+ Csb[i], Cdictb_bis2, p=psb[i], q=qb, reg=0.,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50
+ )
+ total_reconstruction_b_bis2 += reconstruction
+
+ total_reconstruction_b_bis2 = nx.to_numpy(total_reconstruction_b_bis2)
+ np.testing.assert_allclose(total_reconstruction_b_bis2, total_reconstruction_bis2, atol=1e-05)
+
+
+def test_fused_gromov_wasserstein_linear_unmixing(nx):
+
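+ # same setting as the unfused test above, with node features F attached to both structures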
+ n = 4
+ X1, y1 = ot.datasets.make_data_classif('3gauss', n, random_state=42)
+ X2, y2 = ot.datasets.make_data_classif('3gauss2', n, random_state=42)
+ F, y = ot.datasets.make_data_classif('3gauss', n, random_state=42)
+
+ C1 = ot.dist(X1)
+ C2 = ot.dist(X2)
+ Cdict = np.stack([C1, C2])
+ Ydict = np.stack([F, F])
+ p = ot.unif(n)
+
+ C1b, C2b, Fb, Cdictb, Ydictb, pb = nx.from_numpy(C1, C2, F, Cdict, Ydict, p)
+
+ # Tests without regularization
+ reg = 0.
+
+ unmixing1, C1_emb, Y1_emb, OT, reconstruction1 = ot.gromov.fused_gromov_wasserstein_linear_unmixing(
+ C1, F, Cdict, Ydict, p=p, q=p, alpha=0.5, reg=reg,
+ tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=10, max_iter_inner=50
+ )
+
+ unmixing1b, C1b_emb, Y1b_emb, OTb, reconstruction1b = ot.gromov.fused_gromov_wasserstein_linear_unmixing(
+ C1b, Fb, Cdictb, Ydictb, p=None, q=None, alpha=0.5, reg=reg,
+ tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=10, max_iter_inner=50
+ )
+
+ unmixing2, C2_emb, Y2_emb, OT, reconstruction2 = ot.gromov.fused_gromov_wasserstein_linear_unmixing(
+ C2, F, Cdict, Ydict, p=None, q=None, alpha=0.5, reg=reg,
+ tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=10, max_iter_inner=50
+ )
+
+ unmixing2b, C2b_emb, Y2b_emb, OTb, reconstruction2b = ot.gromov.fused_gromov_wasserstein_linear_unmixing(
+ C2b, Fb, Cdictb, Ydictb, p=pb, q=pb, alpha=0.5, reg=reg,
+ tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=10, max_iter_inner=50
+ )
+
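+ # numpy and backend results should match, and each unmixing should favor its matching atom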
+ np.testing.assert_allclose(unmixing1, nx.to_numpy(unmixing1b), atol=4e-06)
+ np.testing.assert_allclose(unmixing1, [1., 0.], atol=4e-01)
+ np.testing.assert_allclose(unmixing2, nx.to_numpy(unmixing2b), atol=4e-06)
+ np.testing.assert_allclose(unmixing2, [0., 1.], atol=4e-01)
+ np.testing.assert_allclose(C1_emb, nx.to_numpy(C1b_emb), atol=1e-03)
+ np.testing.assert_allclose(C2_emb, nx.to_numpy(C2b_emb), atol=1e-03)
+ np.testing.assert_allclose(Y1_emb, nx.to_numpy(Y1b_emb), atol=1e-03)
+ np.testing.assert_allclose(Y2_emb, nx.to_numpy(Y2b_emb), atol=1e-03)
+ np.testing.assert_allclose(reconstruction1, nx.to_numpy(reconstruction1b), atol=1e-06)
+ np.testing.assert_allclose(reconstruction2, nx.to_numpy(reconstruction2b), atol=1e-06)
+ np.testing.assert_allclose(C1b_emb.shape, (n, n))
+ np.testing.assert_allclose(C2b_emb.shape, (n, n))
+
+ # Tests with regularization
+ reg = 0.001
+
+ unmixing1, C1_emb, Y1_emb, OT, reconstruction1 = ot.gromov.fused_gromov_wasserstein_linear_unmixing(
+ C1, F, Cdict, Ydict, p=p, q=p, alpha=0.5, reg=reg,
+ tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=10, max_iter_inner=50
+ )
+
+ unmixing1b, C1b_emb, Y1b_emb, OTb, reconstruction1b = ot.gromov.fused_gromov_wasserstein_linear_unmixing(
+ C1b, Fb, Cdictb, Ydictb, p=None, q=None, alpha=0.5, reg=reg,
+ tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=10, max_iter_inner=50
+ )
+
+ unmixing2, C2_emb, Y2_emb, OT, reconstruction2 = ot.gromov.fused_gromov_wasserstein_linear_unmixing(
+ C2, F, Cdict, Ydict, p=None, q=None, alpha=0.5, reg=reg,
+ tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=10, max_iter_inner=50
+ )
+
+ unmixing2b, C2b_emb, Y2b_emb, OTb, reconstruction2b = ot.gromov.fused_gromov_wasserstein_linear_unmixing(
+ C2b, Fb, Cdictb, Ydictb, p=pb, q=pb, alpha=0.5, reg=reg,
+ tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=10, max_iter_inner=50
+ )
+
+ np.testing.assert_allclose(unmixing1, nx.to_numpy(unmixing1b), atol=1e-06)
+ np.testing.assert_allclose(unmixing1, [1., 0.], atol=1e-01)
+ np.testing.assert_allclose(unmixing2, nx.to_numpy(unmixing2b), atol=1e-06)
+ np.testing.assert_allclose(unmixing2, [0., 1.], atol=1e-01)
+ np.testing.assert_allclose(C1_emb, nx.to_numpy(C1b_emb), atol=1e-03)
+ np.testing.assert_allclose(C2_emb, nx.to_numpy(C2b_emb), atol=1e-03)
+ np.testing.assert_allclose(Y1_emb, nx.to_numpy(Y1b_emb), atol=1e-03)
+ np.testing.assert_allclose(Y2_emb, nx.to_numpy(Y2b_emb), atol=1e-03)
+ np.testing.assert_allclose(reconstruction1, nx.to_numpy(reconstruction1b), atol=1e-06)
+ np.testing.assert_allclose(reconstruction2, nx.to_numpy(reconstruction2b), atol=1e-06)
+ np.testing.assert_allclose(C1b_emb.shape, (n, n))
+ np.testing.assert_allclose(C2b_emb.shape, (n, n))
+
+
+def test_fused_gromov_wasserstein_dictionary_learning(nx):
+
+ # create a dataset composed of 2 structures, each repeated n_samples // 2 times
+ shape = 4
+ n_samples = 2
+ n_atoms = 2
+ projection = 'nonnegative_symmetric'
+ X1, y1 = ot.datasets.make_data_classif('3gauss', shape, random_state=42)
+ X2, y2 = ot.datasets.make_data_classif('3gauss2', shape, random_state=42)
+ F, y = ot.datasets.make_data_classif('3gauss', shape, random_state=42)
+
+ C1 = ot.dist(X1)
+ C2 = ot.dist(X2)
+ Cs = [C1.copy() for _ in range(n_samples // 2)] + [C2.copy() for _ in range(n_samples // 2)]
+ Ys = [F.copy() for _ in range(n_samples)]
+ ps = [ot.unif(shape) for _ in range(n_samples)]
+ q = ot.unif(shape)
+
+ # Provide initialization for the graph dictionary of shape (n_atoms, shape, shape)
+ # following the same procedure as implemented in gromov_wasserstein_dictionary_learning.
+ dataset_structure_means = [C.mean() for C in Cs]
+ np.random.seed(0)
+ Cdict_init = np.random.normal(loc=np.mean(dataset_structure_means), scale=np.std(dataset_structure_means), size=(n_atoms, shape, shape))
+ if projection == 'nonnegative_symmetric':
+ Cdict_init = 0.5 * (Cdict_init + Cdict_init.transpose((0, 2, 1)))
+ Cdict_init[Cdict_init < 0.] = 0.
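+ # initialize the feature atoms from the mean and std of the dataset features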
+ dataset_feature_means = np.stack([Y.mean(axis=0) for Y in Ys])
+ Ydict_init = np.random.normal(loc=dataset_feature_means.mean(axis=0), scale=dataset_feature_means.std(axis=0), size=(n_atoms, shape, 2))
+
+ Csb = nx.from_numpy(*Cs)
+ Ysb = nx.from_numpy(*Ys)
+ psb = nx.from_numpy(*ps)
+ qb, Cdict_initb, Ydict_initb = nx.from_numpy(q, Cdict_init, Ydict_init)
+
+ # Test: Compute initial reconstruction of samples on this random dictionary
+ alpha = 0.5
+ use_adam_optimizer = True
+ verbose = False
+ tol = 1e-05
+ epochs = 1
+
+ initial_total_reconstruction = 0
+ for i in range(n_samples):
+ _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing(
+ Cs[i], Ys[i], Cdict_init, Ydict_init, p=ps[i], q=q,
+ alpha=alpha, reg=0., tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50
+ )
+ initial_total_reconstruction += reconstruction
+
+ # > Learn a dictionary using this initialization and check that the reconstruction loss
+ # on the learned dictionary is lower than with the initial one.
+ Cdict, Ydict, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning(
+ Cs, Ys, D=n_atoms, nt=shape, ps=ps, q=q, Cdict_init=Cdict_init, Ydict_init=Ydict_init,
+ epochs=epochs, batch_size=n_samples, learning_rate_C=1., learning_rate_Y=1., alpha=alpha, reg=0.,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50,
+ projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose
+ )
+ # > Compute reconstruction of samples on learned dictionary
+ total_reconstruction = 0
+ for i in range(n_samples):
+ _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing(
+ Cs[i], Ys[i], Cdict, Ydict, p=None, q=None, alpha=alpha, reg=0.,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50
+ )
+ total_reconstruction += reconstruction
+ # Compare both
+ np.testing.assert_array_less(total_reconstruction, initial_total_reconstruction)
+
+ # Test: Perform the same experiments after going through the backend
+
+ Cdictb, Ydictb, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning(
+ Csb, Ysb, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=Cdict_initb, Ydict_init=Ydict_initb,
+ epochs=epochs, batch_size=2 * n_samples, learning_rate_C=1., learning_rate_Y=1., alpha=alpha, reg=0.,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50,
+ projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose
+ )
+ # > Compute reconstruction of samples on learned dictionary
+ total_reconstruction_b = 0
+ for i in range(n_samples):
+ _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing(
+ Csb[i], Ysb[i], Cdictb, Ydictb, p=psb[i], q=qb, alpha=alpha, reg=0.,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50
+ )
+ total_reconstruction_b += reconstruction
+
+ total_reconstruction_b = nx.to_numpy(total_reconstruction_b)
+ np.testing.assert_array_less(total_reconstruction_b, initial_total_reconstruction)
+ np.testing.assert_allclose(total_reconstruction_b, total_reconstruction, atol=1e-05)
+ np.testing.assert_allclose(Cdict, nx.to_numpy(Cdictb), atol=1e-03)
+ np.testing.assert_allclose(Ydict, nx.to_numpy(Ydictb), atol=1e-03)
+
+ # Test: Perform a similar experiment without providing the initial dictionary (an optional input)
+ np.random.seed(0)
+ Cdict_bis, Ydict_bis, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning(
+ Cs, Ys, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=None, Ydict_init=None,
+ epochs=epochs, batch_size=n_samples, learning_rate_C=1., learning_rate_Y=1., alpha=alpha, reg=0.,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50,
+ projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose
+ )
+ # > Compute reconstruction of samples on learned dictionary
+ total_reconstruction_bis = 0
+ for i in range(n_samples):
+ _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing(
+ Cs[i], Ys[i], Cdict_bis, Ydict_bis, p=ps[i], q=q, alpha=alpha, reg=0.,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50
+ )
+ total_reconstruction_bis += reconstruction
+
+ np.testing.assert_allclose(total_reconstruction_bis, total_reconstruction, atol=1e-05)
+
+ # > Same after going through backend
+ np.random.seed(0)
+ Cdictb_bis, Ydictb_bis, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning(
+ Csb, Ysb, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=None, Ydict_init=None,
+ epochs=epochs, batch_size=n_samples, learning_rate_C=1., learning_rate_Y=1., alpha=alpha, reg=0.,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50,
+ projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose
+ )
+
+ # > Compute reconstruction of samples on learned dictionary
+ total_reconstruction_b_bis = 0
+ for i in range(n_samples):
+ _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing(
+ Csb[i], Ysb[i], Cdictb_bis, Ydictb_bis, p=psb[i], q=qb, alpha=alpha, reg=0.,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50
+ )
+ total_reconstruction_b_bis += reconstruction
+
+ total_reconstruction_b_bis = nx.to_numpy(total_reconstruction_b_bis)
+ np.testing.assert_allclose(total_reconstruction_b_bis, total_reconstruction_b, atol=1e-05)
+
+ # Test: without using adam optimizer, with log and verbose set to True
+ use_adam_optimizer = False
+ verbose = True
+ use_log = True
+
+ # > Provide the previously estimated dictionaries to speed up the test compared to a random initialization.
+ np.random.seed(0)
+ Cdict_bis2, Ydict_bis2, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning(
+ Cs, Ys, D=n_atoms, nt=shape, ps=ps, q=q, Cdict_init=Cdict, Ydict_init=Ydict,
+ epochs=epochs, batch_size=n_samples, learning_rate_C=10., learning_rate_Y=10., alpha=alpha, reg=0.,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50,
+ projection=projection, use_log=use_log, use_adam_optimizer=use_adam_optimizer, verbose=verbose
+ )
+ # > Compute reconstruction of samples on learned dictionary
+ total_reconstruction_bis2 = 0
+ for i in range(n_samples):
+ _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing(
+ Cs[i], Ys[i], Cdict_bis2, Ydict_bis2, p=ps[i], q=q, alpha=alpha, reg=0.,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50
+ )
+ total_reconstruction_bis2 += reconstruction
+
+ np.testing.assert_array_less(total_reconstruction_bis2, total_reconstruction)
+
+ # > Same after going through backend
+ np.random.seed(0)
+ Cdictb_bis2, Ydictb_bis2, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning(
+ Csb, Ysb, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=Cdictb, Ydict_init=Ydictb,
+ epochs=epochs, batch_size=n_samples, learning_rate_C=10., learning_rate_Y=10., alpha=alpha, reg=0.,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50,
+ projection=projection, use_log=use_log, use_adam_optimizer=use_adam_optimizer, verbose=verbose
+ )
+
+ # > Compute reconstruction of samples on learned dictionary
+ total_reconstruction_b_bis2 = 0
+ for i in range(n_samples):
+ _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing(
+ Csb[i], Ysb[i], Cdictb_bis2, Ydictb_bis2, p=None, q=None, alpha=alpha, reg=0.,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50
+ )
+ total_reconstruction_b_bis2 += reconstruction
+
+ # > Compare results with/without backend
+ total_reconstruction_b_bis2 = nx.to_numpy(total_reconstruction_b_bis2)
+ np.testing.assert_allclose(total_reconstruction_bis2, total_reconstruction_b_bis2, atol=1e-05)
diff --git a/test/test_optim.py b/test/test_optim.py
index 41f9cbe..67e9d13 100644
--- a/test/test_optim.py
+++ b/test/test_optim.py
@@ -32,9 +32,7 @@ def test_conditional_gradient(nx):
def fb(G):
return 0.5 * nx.sum(G ** 2)
- ab = nx.from_numpy(a)
- bb = nx.from_numpy(b)
- Mb = nx.from_numpy(M, type_as=ab)
+ ab, bb, Mb = nx.from_numpy(a, b, M)
reg = 1e-1
@@ -74,9 +72,7 @@ def test_conditional_gradient_itermax(nx):
def fb(G):
return 0.5 * nx.sum(G ** 2)
- ab = nx.from_numpy(a)
- bb = nx.from_numpy(b)
- Mb = nx.from_numpy(M, type_as=ab)
+ ab, bb, Mb = nx.from_numpy(a, b, M)
reg = 1e-1
@@ -118,9 +114,7 @@ def test_generalized_conditional_gradient(nx):
reg1 = 1e-3
reg2 = 1e-1
- ab = nx.from_numpy(a)
- bb = nx.from_numpy(b)
- Mb = nx.from_numpy(M, type_as=ab)
+ ab, bb, Mb = nx.from_numpy(a, b, M)
G, log = ot.optim.gcg(a, b, M, reg1, reg2, f, df, verbose=True, log=True)
Gb, log = ot.optim.gcg(ab, bb, Mb, reg1, reg2, fb, df, verbose=True, log=True)
@@ -142,9 +136,12 @@ def test_line_search_armijo(nx):
pk = np.array([[-0.25, 0.25], [0.25, -0.25]])
gfk = np.array([[23.04273441, 23.0449082], [23.04273441, 23.0449082]])
old_fval = -123
+
+ xkb, pkb, gfkb = nx.from_numpy(xk, pk, gfk)
+
# Should not throw an exception and return 0. for alpha
alpha, a, b = ot.optim.line_search_armijo(
- lambda x: 1, nx.from_numpy(xk), nx.from_numpy(pk), nx.from_numpy(gfk), old_fval
+ lambda x: 1, xkb, pkb, gfkb, old_fval
)
alpha_np, anp, bnp = ot.optim.line_search_armijo(
lambda x: 1, xk, pk, gfk, old_fval
diff --git a/test/test_ot.py b/test/test_ot.py
index 53edf4f..bf832f6 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -47,8 +47,7 @@ def test_emd_backends(nx):
G = ot.emd(a, a, M)
- ab = nx.from_numpy(a)
- Mb = nx.from_numpy(M)
+ ab, Mb = nx.from_numpy(a, M)
Gb = ot.emd(ab, ab, Mb)
@@ -68,8 +67,7 @@ def test_emd2_backends(nx):
val = ot.emd2(a, a, M)
- ab = nx.from_numpy(a)
- Mb = nx.from_numpy(M)
+ ab, Mb = nx.from_numpy(a, M)
valb = ot.emd2(ab, ab, Mb)
@@ -90,8 +88,7 @@ def test_emd_emd2_types_devices(nx):
for tp in nx.__type_list__:
print(nx.dtype_device(tp))
- ab = nx.from_numpy(a, type_as=tp)
- Mb = nx.from_numpy(M, type_as=tp)
+ ab, Mb = nx.from_numpy(a, M, type_as=tp)
Gb = ot.emd(ab, ab, Mb)
@@ -117,8 +114,7 @@ def test_emd_emd2_devices_tf():
# Check that everything stays on the CPU
with tf.device("/CPU:0"):
- ab = nx.from_numpy(a)
- Mb = nx.from_numpy(M)
+ ab, Mb = nx.from_numpy(a, M)
Gb = ot.emd(ab, ab, Mb)
w = ot.emd2(ab, ab, Mb)
nx.assert_same_dtype_device(Mb, Gb)
@@ -126,8 +122,7 @@
if len(tf.config.list_physical_devices('GPU')) > 0:
# Check that everything happens on the GPU
- ab = nx.from_numpy(a)
- Mb = nx.from_numpy(M)
+ ab, Mb = nx.from_numpy(a, M)
Gb = ot.emd(ab, ab, Mb)
w = ot.emd2(ab, ab, Mb)
nx.assert_same_dtype_device(Mb, Gb)
@@ -152,7 +147,7 @@ def test_emd2_gradients():
b1 = torch.tensor(a, requires_grad=True)
M1 = torch.tensor(M, requires_grad=True)
- val = ot.emd2(a1, b1, M1)
+ val, log = ot.emd2(a1, b1, M1, log=True)
val.backward()
@@ -160,6 +155,12 @@
assert b1.shape == b1.grad.shape
assert M1.shape == M1.grad.shape
+ assert np.allclose(a1.grad.cpu().detach().numpy(),
+ log['u'].cpu().detach().numpy() - log['u'].cpu().detach().numpy().mean())
+
+ assert np.allclose(b1.grad.cpu().detach().numpy(),
+ log['v'].cpu().detach().numpy() - log['v'].cpu().detach().numpy().mean())
+
# Testing for bug #309, checking for scaling of gradient
a2 = torch.tensor(a, requires_grad=True)
b2 = torch.tensor(a, requires_grad=True)
@@ -232,7 +233,7 @@ def test_emd2_multi():
# Gaussian distributions
a = gauss(n, m=20, s=5) # m= mean, s= std
- ls = np.arange(20, 500, 20)
+ ls = np.arange(20, 500, 100)
nb = len(ls)
b = np.zeros((n, nb))
for i in range(nb):
@@ -302,6 +303,23 @@ def test_free_support_barycenter():
np.testing.assert_allclose(X, bar_locations, rtol=1e-5, atol=1e-7)
+def test_free_support_barycenter_backends(nx):
+
+ measures_locations = [np.array([-1.]).reshape((1, 1)), np.array([1.]).reshape((1, 1))]
+ measures_weights = [np.array([1.]), np.array([1.])]
+ X_init = np.array([-12.]).reshape((1, 1))
+
+ X = ot.lp.free_support_barycenter(measures_locations, measures_weights, X_init)
+
+ measures_locations2 = nx.from_numpy(*measures_locations)
+ measures_weights2 = nx.from_numpy(*measures_weights)
+ X_init2 = nx.from_numpy(X_init)
+
+ X2 = ot.lp.free_support_barycenter(measures_locations2, measures_weights2, X_init2)
+
+ np.testing.assert_allclose(X, nx.to_numpy(X2))
+
+
@pytest.mark.skipif(not ot.lp.cvx.cvxopt, reason="No cvxopt available")
def test_lp_barycenter_cvxopt():
a1 = np.array([1.0, 0, 0])[:, None]
diff --git a/test/test_sliced.py b/test/test_sliced.py
index 91e0961..08ab4fb 100644
--- a/test/test_sliced.py
+++ b/test/test_sliced.py
@@ -123,9 +123,7 @@ def test_sliced_backend(nx):
n_projections = 20
- xb = nx.from_numpy(x)
- yb = nx.from_numpy(y)
- Pb = nx.from_numpy(P)
+ xb, yb, Pb = nx.from_numpy(x, y, P)
val0 = ot.sliced_wasserstein_distance(x, y, projections=P)
@@ -153,9 +151,7 @@ def test_sliced_backend_type_devices(nx):
for tp in nx.__type_list__:
print(nx.dtype_device(tp))
- xb = nx.from_numpy(x, type_as=tp)
- yb = nx.from_numpy(y, type_as=tp)
- Pb = nx.from_numpy(P, type_as=tp)
+ xb, yb, Pb = nx.from_numpy(x, y, P, type_as=tp)
valb = ot.sliced_wasserstein_distance(xb, yb, projections=Pb)
@@ -174,17 +170,13 @@ def test_sliced_backend_device_tf():
# Check that everything stays on the CPU
with tf.device("/CPU:0"):
- xb = nx.from_numpy(x)
- yb = nx.from_numpy(y)
- Pb = nx.from_numpy(P)
+ xb, yb, Pb = nx.from_numpy(x, y, P)
valb = ot.sliced_wasserstein_distance(xb, yb, projections=Pb)
nx.assert_same_dtype_device(xb, valb)
if len(tf.config.list_physical_devices('GPU')) > 0:
# Check that everything happens on the GPU
- xb = nx.from_numpy(x)
- yb = nx.from_numpy(y)
- Pb = nx.from_numpy(P)
+ xb, yb, Pb = nx.from_numpy(x, y, P)
valb = ot.sliced_wasserstein_distance(xb, yb, projections=Pb)
nx.assert_same_dtype_device(xb, valb)
assert nx.dtype_device(valb)[1].startswith("GPU")
@@ -203,9 +195,7 @@ def test_max_sliced_backend(nx):
n_projections = 20
- xb = nx.from_numpy(x)
- yb = nx.from_numpy(y)
- Pb = nx.from_numpy(P)
+ xb, yb, Pb = nx.from_numpy(x, y, P)
val0 = ot.max_sliced_wasserstein_distance(x, y, projections=P)
@@ -233,9 +223,7 @@ def test_max_sliced_backend_type_devices(nx):
for tp in nx.__type_list__:
print(nx.dtype_device(tp))
- xb = nx.from_numpy(x, type_as=tp)
- yb = nx.from_numpy(y, type_as=tp)
- Pb = nx.from_numpy(P, type_as=tp)
+ xb, yb, Pb = nx.from_numpy(x, y, P, type_as=tp)
valb = ot.max_sliced_wasserstein_distance(xb, yb, projections=Pb)
@@ -254,17 +242,13 @@ def test_max_sliced_backend_device_tf():
# Check that everything stays on the CPU
with tf.device("/CPU:0"):
- xb = nx.from_numpy(x)
- yb = nx.from_numpy(y)
- Pb = nx.from_numpy(P)
+ xb, yb, Pb = nx.from_numpy(x, y, P)
valb = ot.max_sliced_wasserstein_distance(xb, yb, projections=Pb)
nx.assert_same_dtype_device(xb, valb)
if len(tf.config.list_physical_devices('GPU')) > 0:
# Check that everything happens on the GPU
- xb = nx.from_numpy(x)
- yb = nx.from_numpy(y)
- Pb = nx.from_numpy(P)
+ xb, yb, Pb = nx.from_numpy(x, y, P)
valb = ot.max_sliced_wasserstein_distance(xb, yb, projections=Pb)
nx.assert_same_dtype_device(xb, valb)
assert nx.dtype_device(valb)[1].startswith("GPU")
diff --git a/test/test_stochastic.py b/test/test_stochastic.py
index 736df32..2b5c0fb 100644
--- a/test/test_stochastic.py
+++ b/test/test_stochastic.py
@@ -8,7 +8,8 @@
for discrete and semicontinuous measures from the POT library.
"""
-# Author: Kilian Fatras <kilian.fatras@gmail.com>
+# Authors: Kilian Fatras <kilian.fatras@gmail.com>
+# Rémi Flamary <remi.flamary@polytechnique.edu>
#
# License: MIT License
@@ -213,3 +214,115 @@ def test_dual_sgd_sinkhorn():
G_sgd.sum(0), G_sinkhorn.sum(0), atol=1e-03)
np.testing.assert_allclose(
G_sgd, G_sinkhorn, atol=1e-03) # cf convergence sgd
+
+
+def test_loss_dual_entropic(nx):
+
+ nx.seed(0)
+
+ xs = nx.randn(50, 2)
+ xt = nx.randn(40, 2) + 2
+
+ ws = nx.rand(50)
+ ws = ws / nx.sum(ws)
+
+ wt = nx.rand(40)
+ wt = wt / nx.sum(wt)
+
+ u = nx.randn(50)
+ v = nx.randn(40)
+
+ def metric(x, y):
+ return -nx.dot(x, y.T)
+
+ ot.stochastic.loss_dual_entropic(u, v, xs, xt)
+
+ ot.stochastic.loss_dual_entropic(u, v, xs, xt, ws=ws, wt=wt, metric=metric)
+
+
+def test_plan_dual_entropic(nx):
+
+ nx.seed(0)
+
+ xs = nx.randn(50, 2)
+ xt = nx.randn(40, 2) + 2
+
+ ws = nx.rand(50)
+ ws = ws / nx.sum(ws)
+
+ wt = nx.rand(40)
+ wt = wt / nx.sum(wt)
+
+ u = nx.randn(50)
+ v = nx.randn(40)
+
+ def metric(x, y):
+ return -nx.dot(x, y.T)
+
+ G1 = ot.stochastic.plan_dual_entropic(u, v, xs, xt)
+
+ assert np.all(nx.to_numpy(G1) >= 0)
+ assert G1.shape[0] == 50
+ assert G1.shape[1] == 40
+
+ G2 = ot.stochastic.plan_dual_entropic(u, v, xs, xt, ws=ws, wt=wt, metric=metric)
+
+ assert np.all(nx.to_numpy(G2) >= 0)
+ assert G2.shape[0] == 50
+ assert G2.shape[1] == 40
+
+
+def test_loss_dual_quadratic(nx):
+
+ nx.seed(0)
+
+ xs = nx.randn(50, 2)
+ xt = nx.randn(40, 2) + 2
+
+ ws = nx.rand(50)
+ ws = ws / nx.sum(ws)
+
+ wt = nx.rand(40)
+ wt = wt / nx.sum(wt)
+
+ u = nx.randn(50)
+ v = nx.randn(40)
+
+ def metric(x, y):
+ return -nx.dot(x, y.T)
+
+ ot.stochastic.loss_dual_quadratic(u, v, xs, xt)
+
+ ot.stochastic.loss_dual_quadratic(u, v, xs, xt, ws=ws, wt=wt, metric=metric)
+
+
+def test_plan_dual_quadratic(nx):
+
+ nx.seed(0)
+
+ xs = nx.randn(50, 2)
+ xt = nx.randn(40, 2) + 2
+
+ ws = nx.rand(50)
+ ws = ws / nx.sum(ws)
+
+ wt = nx.rand(40)
+ wt = wt / nx.sum(wt)
+
+ u = nx.randn(50)
+ v = nx.randn(40)
+
+ def metric(x, y):
+ return -nx.dot(x, y.T)
+
+ G1 = ot.stochastic.plan_dual_quadratic(u, v, xs, xt)
+
+ assert np.all(nx.to_numpy(G1) >= 0)
+ assert G1.shape[0] == 50
+ assert G1.shape[1] == 40
+
+ G2 = ot.stochastic.plan_dual_quadratic(u, v, xs, xt, ws=ws, wt=wt, metric=metric)
+
+ assert np.all(nx.to_numpy(G2) >= 0)
+ assert G2.shape[0] == 50
+ assert G2.shape[1] == 40
diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py
index e8349d1..02b3fc3 100644
--- a/test/test_unbalanced.py
+++ b/test/test_unbalanced.py
@@ -1,6 +1,7 @@
"""Tests for module Unbalanced OT with entropy regularization"""
# Author: Hicham Janati <hicham.janati@inria.fr>
+# Laetitia Chapel <laetitia.chapel@univ-ubs.fr>
#
# License: MIT License
@@ -9,11 +10,9 @@
import ot
import pytest
from ot.unbalanced import barycenter_unbalanced
-from scipy.special import logsumexp
-
@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"])
-def test_unbalanced_convergence(method):
+def test_unbalanced_convergence(nx, method):
# test generalized sinkhorn for unbalanced OT
n = 100
rng = np.random.RandomState(42)
@@ -28,36 +27,51 @@ def test_unbalanced_convergence(method):
epsilon = 1.
reg_m = 1.
+ a, b, M = nx.from_numpy(a, b, M)
+
G, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, reg_m=reg_m, method=method, log=True, verbose=True)
- loss = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m,
- method=method,
- verbose=True)
+ loss = nx.to_numpy(ot.unbalanced.sinkhorn_unbalanced2(
+ a, b, M, epsilon, reg_m, method=method, verbose=True
+ ))
# check fixed point equations
# in log-domain
fi = reg_m / (reg_m + epsilon)
- logb = np.log(b + 1e-16)
- loga = np.log(a + 1e-16)
- logKtu = logsumexp(log["logu"][None, :] - M.T / epsilon, axis=1)
- logKv = logsumexp(log["logv"][None, :] - M / epsilon, axis=1)
+ logb = nx.log(b + 1e-16)
+ loga = nx.log(a + 1e-16)
+ logKtu = nx.logsumexp(log["logu"][None, :] - M.T / epsilon, axis=1)
+ logKv = nx.logsumexp(log["logv"][None, :] - M / epsilon, axis=1)
v_final = fi * (logb - logKtu)
u_final = fi * (loga - logKv)
np.testing.assert_allclose(
- u_final, log["logu"], atol=1e-05)
+ nx.to_numpy(u_final), nx.to_numpy(log["logu"]), atol=1e-05)
np.testing.assert_allclose(
- v_final, log["logv"], atol=1e-05)
+ nx.to_numpy(v_final), nx.to_numpy(log["logv"]), atol=1e-05)
# check if sinkhorn_unbalanced2 returns the correct loss
- np.testing.assert_allclose((G * M).sum(), loss, atol=1e-5)
+ np.testing.assert_allclose(nx.to_numpy(nx.sum(G * M)), loss, atol=1e-5)
+
+ # check in case no histogram is provided
+ M_np = nx.to_numpy(M)
+ a_np, b_np = np.array([]), np.array([])
+ a, b = nx.from_numpy(a_np, b_np)
+
+ G = ot.unbalanced.sinkhorn_unbalanced(
+ a, b, M, reg=epsilon, reg_m=reg_m, method=method, verbose=True
+ )
+ G_np = ot.unbalanced.sinkhorn_unbalanced(
+ a_np, b_np, M_np, reg=epsilon, reg_m=reg_m, method=method, verbose=True
+ )
+ np.testing.assert_allclose(G_np, nx.to_numpy(G))
@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"])
-def test_unbalanced_multiple_inputs(method):
+def test_unbalanced_multiple_inputs(nx, method):
# test generalized sinkhorn for unbalanced OT
n = 100
rng = np.random.RandomState(42)
@@ -72,6 +86,8 @@ def test_unbalanced_multiple_inputs(method):
epsilon = 1.
reg_m = 1.
+ a, b, M = nx.from_numpy(a, b, M)
+
loss, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, reg_m=reg_m, method=method,
@@ -80,23 +96,24 @@
# check fixed point equations
# in log-domain
fi = reg_m / (reg_m + epsilon)
- logb = np.log(b + 1e-16)
- loga = np.log(a + 1e-16)[:, None]
- logKtu = logsumexp(log["logu"][:, None, :] - M[:, :, None] / epsilon,
- axis=0)
- logKv = logsumexp(log["logv"][None, :] - M[:, :, None] / epsilon, axis=1)
+ logb = nx.log(b + 1e-16)
+ loga = nx.log(a + 1e-16)[:, None]
+ logKtu = nx.logsumexp(
+ log["logu"][:, None, :] - M[:, :, None] / epsilon, axis=0
+ )
+ logKv = nx.logsumexp(log["logv"][None, :] - M[:, :, None] / epsilon, axis=1)
v_final = fi * (logb - logKtu)
u_final = fi * (loga - logKv)
np.testing.assert_allclose(
- u_final, log["logu"], atol=1e-05)
+ nx.to_numpy(u_final), nx.to_numpy(log["logu"]), atol=1e-05)
np.testing.assert_allclose(
- v_final, log["logv"], atol=1e-05)
+ nx.to_numpy(v_final), nx.to_numpy(log["logv"]), atol=1e-05)
assert len(loss) == b.shape[1]
-def test_stabilized_vs_sinkhorn():
+def test_stabilized_vs_sinkhorn(nx):
# test if stable version matches sinkhorn
n = 100
@@ -112,19 +129,27 @@
M /= np.median(M)
epsilon = 0.1
reg_m = 1.
- G, log = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, reg=epsilon,
- method="sinkhorn_stabilized",
- reg_m=reg_m,
- log=True,
- verbose=True)
- G2, log2 = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m,
- method="sinkhorn", log=True)
+
+ ab, bb, Mb = nx.from_numpy(a, b, M)
+
+ G, _ = ot.unbalanced.sinkhorn_unbalanced2(
+ ab, bb, Mb, epsilon, reg_m, method="sinkhorn_stabilized", log=True
+ )
+ G2, _ = ot.unbalanced.sinkhorn_unbalanced2(
+ ab, bb, Mb, epsilon, reg_m, method="sinkhorn", log=True
+ )
+ G2_np, _ = ot.unbalanced.sinkhorn_unbalanced2(
+ a, b, M, epsilon, reg_m, method="sinkhorn", log=True
+ )
+ G = nx.to_numpy(G)
+ G2 = nx.to_numpy(G2)
np.testing.assert_allclose(G, G2, atol=1e-5)
+ np.testing.assert_allclose(G2, G2_np, atol=1e-5)
@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"])
-def test_unbalanced_barycenter(method):
+def test_unbalanced_barycenter(nx, method):
# test generalized sinkhorn for unbalanced OT barycenter
n = 100
rng = np.random.RandomState(42)
@@ -138,25 +163,29 @@
epsilon = 1.
reg_m = 1.
- q, log = barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m,
- method=method, log=True, verbose=True)
+ A, M = nx.from_numpy(A, M)
+
+ q, log = barycenter_unbalanced(
+ A, M, reg=epsilon, reg_m=reg_m, method=method, log=True, verbose=True
+ )
# check fixed point equations
fi = reg_m / (reg_m + epsilon)
- logA = np.log(A + 1e-16)
- logq = np.log(q + 1e-16)[:, None]
- logKtu = logsumexp(log["logu"][:, None, :] - M[:, :, None] / epsilon,
- axis=0)
- logKv = logsumexp(log["logv"][None, :] - M[:, :, None] / epsilon, axis=1)
+ logA = nx.log(A + 1e-16)
+ logq = nx.log(q + 1e-16)[:, None]
+ logKtu = nx.logsumexp(
+ log["logu"][:, None, :] - M[:, :, None] / epsilon, axis=0
+ )
+ logKv = nx.logsumexp(log["logv"][None, :] - M[:, :, None] / epsilon, axis=1)
v_final = fi * (logq - logKtu)
u_final = fi * (logA - logKv)
np.testing.assert_allclose(
- u_final, log["logu"], atol=1e-05)
+ nx.to_numpy(u_final), nx.to_numpy(log["logu"]), atol=1e-05)
np.testing.assert_allclose(
- v_final, log["logv"], atol=1e-05)
+ nx.to_numpy(v_final), nx.to_numpy(log["logv"]), atol=1e-05)
-def test_barycenter_stabilized_vs_sinkhorn():
+def test_barycenter_stabilized_vs_sinkhorn(nx):
# test generalized sinkhorn for unbalanced OT barycenter
n = 100
rng = np.random.RandomState(42)
@@ -170,21 +199,24 @@
epsilon = 0.5
reg_m = 10
- qstable, log = barycenter_unbalanced(A, M, reg=epsilon,
- reg_m=reg_m, log=True,
- tau=100,
- method="sinkhorn_stabilized",
- verbose=True
- )
- q, log = barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m,
- method="sinkhorn",
- log=True)
+ Ab, Mb = nx.from_numpy(A, M)
- np.testing.assert_allclose(
- q, qstable, atol=1e-05)
+ qstable, _ = barycenter_unbalanced(
+ Ab, Mb, reg=epsilon, reg_m=reg_m, log=True, tau=100,
+ method="sinkhorn_stabilized", verbose=True
+ )
+ q, _ = barycenter_unbalanced(
+ Ab, Mb, reg=epsilon, reg_m=reg_m, method="sinkhorn", log=True
+ )
+ q_np, _ = barycenter_unbalanced(
+ A, M, reg=epsilon, reg_m=reg_m, method="sinkhorn", log=True
+ )
+ q, qstable = nx.to_numpy(q, qstable)
+ np.testing.assert_allclose(q, qstable, atol=1e-05)
+ np.testing.assert_allclose(q, q_np, atol=1e-05)
-def test_wrong_method():
+def test_wrong_method(nx):
n = 10
rng = np.random.RandomState(42)
@@ -199,19 +231,20 @@
epsilon = 1.
reg_m = 1.
+ a, b, M = nx.from_numpy(a, b, M)
+
with pytest.raises(ValueError):
- ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon,
- reg_m=reg_m,
- method='badmethod',
- log=True,
- verbose=True)
+ ot.unbalanced.sinkhorn_unbalanced(
+ a, b, M, reg=epsilon, reg_m=reg_m, method='badmethod',
+ log=True, verbose=True
+ )
with pytest.raises(ValueError):
- ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m,
- method='badmethod',
- verbose=True)
+ ot.unbalanced.sinkhorn_unbalanced2(
+ a, b, M, epsilon, reg_m, method='badmethod', verbose=True
+ )
-def test_implemented_methods():
+def test_implemented_methods(nx):
IMPLEMENTED_METHODS = ['sinkhorn', 'sinkhorn_stabilized']
TO_BE_IMPLEMENTED_METHODS = ['sinkhorn_reg_scaling']
NOT_VALID_TOKENS = ['foo']
@@ -228,6 +261,9 @@ def test_implemented_methods():
M = ot.dist(x, x)
epsilon = 1.
reg_m = 1.
+
+ a, b, M, A = nx.from_numpy(a, b, M, A)
+
for method in IMPLEMENTED_METHODS:
ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, reg_m, method=method)
@@ -251,3 +287,52 @@ def test_implemented_methods():
method=method)
barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m, method=method)
+
+
+def test_mm_convergence(nx):
+ n = 100
+ rng = np.random.RandomState(42)
+ x = rng.randn(n, 2)
+ rng = np.random.RandomState(75)
+ y = rng.randn(n, 2)
+ a = ot.utils.unif(n)
+ b = ot.utils.unif(n)
+
+ M = ot.dist(x, y)
+ M = M / M.max()
+ reg_m = 100
+ a, b, M = nx.from_numpy(a, b, M)
+
+ G_kl, _ = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='kl',
+ verbose=True, log=True)
+ loss_kl = nx.to_numpy(ot.unbalanced.mm_unbalanced2(
+ a, b, M, reg_m, div='kl', verbose=True))
+ G_l2, _ = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='l2',
+ verbose=False, log=True)
+
+ # check if the marginals come close to the true ones when large reg
+ np.testing.assert_allclose(np.sum(nx.to_numpy(G_kl), 1), a, atol=1e-03)
+ np.testing.assert_allclose(np.sum(nx.to_numpy(G_kl), 0), b, atol=1e-03)
+ np.testing.assert_allclose(np.sum(nx.to_numpy(G_l2), 1), a, atol=1e-03)
+ np.testing.assert_allclose(np.sum(nx.to_numpy(G_l2), 0), b, atol=1e-03)
+
+ # check if mm_unbalanced2 returns the correct loss
+ np.testing.assert_allclose(nx.to_numpy(nx.sum(G_kl * M)), loss_kl,
+ atol=1e-5)
+
+ # check in case no histogram is provided
+ a_np, b_np = np.array([]), np.array([])
+ a, b = nx.from_numpy(a_np, b_np)
+
+ G_kl_null = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='kl')
+ G_l2_null = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='l2')
+ np.testing.assert_allclose(G_kl_null, G_kl)
+ np.testing.assert_allclose(G_l2_null, G_l2)
+
+ # test when G0 is given
+ G0 = ot.emd(a, b, M)
+ reg_m = 10000
+ G_kl = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='kl', G0=G0)
+ G_l2 = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='l2', G0=G0)
+ np.testing.assert_allclose(G0, G_kl, atol=1e-05)
+ np.testing.assert_allclose(G0, G_l2, atol=1e-05)
diff --git a/test/test_utils.py b/test/test_utils.py
index 6b476b2..3cfd295 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -62,12 +62,14 @@ def test_tic_toc():
import time
ot.tic()
- time.sleep(0.5)
+ time.sleep(0.1)
t = ot.toc()
t2 = ot.toq()
# test timing
- np.testing.assert_allclose(0.5, t, rtol=1e-1, atol=1e-1)
+ # np.testing.assert_allclose(0.1, t, rtol=1e-1, atol=1e-1)
+ # an equality check is not possible on the very slow macOS GitHub Action runners
+ assert t > 0.09
# test toc vs toq
np.testing.assert_allclose(t, t2, rtol=1e-1, atol=1e-1)
@@ -94,10 +96,22 @@ def test_unif():
np.testing.assert_allclose(1, np.sum(u))
-def test_dist():
+def test_unif_backend(nx):
n = 100
+ for tp in nx.__type_list__:
+ print(nx.dtype_device(tp))
+
+ u = ot.unif(n, type_as=tp)
+
+ np.testing.assert_allclose(1, np.sum(nx.to_numpy(u)), atol=1e-6)
+
+
+def test_dist():
+
+ n = 10
+
rng = np.random.RandomState(0)
x = rng.randn(n, 2)
@@ -122,7 +136,7 @@ def test_dist():
'braycurtis', 'canberra', 'chebyshev', 'cityblock', 'correlation', 'cosine', 'dice', 'euclidean', 'hamming', 'jaccard', 'kulsinski', 'matching', 'minkowski', 'rogerstanimoto', 'russellrao',
- 'sokalmichener', 'sokalsneath', 'sqeuclidean', 'wminkowski', 'yule'
+ 'sokalmichener', 'sokalsneath', 'sqeuclidean', 'yule'
] # those that support weights
metrics = ['mahalanobis', 'seuclidean'] # do not support weights depending on scipy's version
diff --git a/test/test_weak.py b/test/test_weak.py
new file mode 100644
index 0000000..945efb1
--- /dev/null
+++ b/test/test_weak.py
@@ -0,0 +1,52 @@
+"""Tests for main module ot.weak """
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+#
+# License: MIT License
+
+import ot
+import numpy as np
+
+
+def test_weak_ot():
+ # test weak ot solver and identity stationary point
+ n = 50
+ rng = np.random.RandomState(0)
+
+ xs = rng.randn(n, 2)
+ xt = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ G, log = ot.weak_optimal_transport(xs, xt, u, u, log=True)
+
+ # check constraints
+ np.testing.assert_allclose(u, G.sum(1))
+ np.testing.assert_allclose(u, G.sum(0))
+
+ # check that the identity is recovered
+ G = ot.weak_optimal_transport(xs, xs, G0=np.eye(n) / n)
+
+ # check G is identity
+ np.testing.assert_allclose(G, np.eye(n) / n)
+
+ # check constraints
+ np.testing.assert_allclose(u, G.sum(1))
+ np.testing.assert_allclose(u, G.sum(0))
+
+
+def test_weak_ot_bakends(nx):
+ # test weak ot solver for different backends
+ n = 50
+ rng = np.random.RandomState(0)
+
+ xs = rng.randn(n, 2)
+ xt = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ G = ot.weak_optimal_transport(xs, xt, u, u)
+
+ xs2, xt2, u2 = nx.from_numpy(xs, xt, u)
+
+ G2 = ot.weak_optimal_transport(xs2, xt2, u2, u2)
+
+ np.testing.assert_allclose(nx.to_numpy(G2), G)