From 50c0f17d00e3492c4d56a356af30cf00d6d07913 Mon Sep 17 00:00:00 2001 From: Cédric Vincent-Cuaz Date: Fri, 11 Feb 2022 10:53:38 +0100 Subject: [MRG] GW dictionary learning (#319) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add fgw dictionary learning feature * add fgw dictionary learning feature * plot gromov wasserstein dictionary learning * Update __init__.py * fix pep8 errors exact E501 line too long * fix last pep8 issues * add unitary tests for (F)GW dictionary learning without using autodifferentiable functions * correct tests for (F)GW dictionary learning without using autodiff * correct tests for (F)GW dictionary learning without using autodiff * fix docs and notations * answer to review: improve tests, docs, examples + make node weights optional * fix pep8 and examples * improve docs + tests + thumbnail * make example faster * improve ex * update README.md * make GDL tests faster Co-authored-by: Rémi Flamary --- test/test_gromov.py | 554 ++++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 542 insertions(+), 12 deletions(-) (limited to 'test/test_gromov.py') diff --git a/test/test_gromov.py b/test/test_gromov.py index 4b995d5..329f99c 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -3,6 +3,7 @@ # Author: Erwan Vautier # Nicolas Courty # Titouan Vayer +# Cédric Vincent-Cuaz # # License: MIT License @@ -26,6 +27,7 @@ def test_gromov(nx): p = ot.unif(n_samples) q = ot.unif(n_samples) + G0 = p[:, None] * q[None, :] C1 = ot.dist(xs, xs) C2 = ot.dist(xt, xt) @@ -37,9 +39,10 @@ def test_gromov(nx): C2b = nx.from_numpy(C2) pb = nx.from_numpy(p) qb = nx.from_numpy(q) + G0b = nx.from_numpy(G0) - G = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss', verbose=True) - Gb = nx.to_numpy(ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', verbose=True)) + G = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss', G0=G0, verbose=True) + Gb = nx.to_numpy(ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', G0=G0b, verbose=True)) # check constraints np.testing.assert_allclose(G, Gb, atol=1e-06) @@ -56,9 +59,9 @@ def test_gromov(nx): gwb, logb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=True) gwb = nx.to_numpy(gwb) - gw_val = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', log=False) + gw_val = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', G0=G0, log=False) gw_valb = nx.to_numpy( - ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=False) + ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', G0=G0b, log=False) ) G = log['T'] @@ -91,6 +94,7 @@ def test_gromov_dtype_device(nx): p = ot.unif(n_samples) q = ot.unif(n_samples) + G0 = p[:, None] * q[None, :] C1 = ot.dist(xs, xs) C2 = ot.dist(xt, xt) @@ -105,9 +109,10 @@ def test_gromov_dtype_device(nx): C2b = nx.from_numpy(C2, type_as=tp) pb = nx.from_numpy(p, type_as=tp) qb = nx.from_numpy(q, type_as=tp) + G0b = nx.from_numpy(G0, type_as=tp) - Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', verbose=True) - gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=False) + Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', G0=G0b, verbose=True) + gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', G0=G0b, log=False) nx.assert_same_dtype_device(C1b, Gb) nx.assert_same_dtype_device(C1b, gw_valb) @@ -123,6 +128,7 @@ def test_gromov_device_tf(): xt = xs[::-1].copy() p = ot.unif(n_samples) q = ot.unif(n_samples) + G0 = p[:, None] * 
q[None, :] C1 = ot.dist(xs, xs) C2 = ot.dist(xt, xt) C1 /= C1.max() @@ -134,8 +140,9 @@ def test_gromov_device_tf(): C2b = nx.from_numpy(C2) pb = nx.from_numpy(p) qb = nx.from_numpy(q) - Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', verbose=True) - gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=False) + G0b = nx.from_numpy(G0) + Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', G0=G0b, verbose=True) + gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', G0=G0b, log=False) nx.assert_same_dtype_device(C1b, Gb) nx.assert_same_dtype_device(C1b, gw_valb) @@ -145,6 +152,7 @@ def test_gromov_device_tf(): C2b = nx.from_numpy(C2) pb = nx.from_numpy(p) qb = nx.from_numpy(q) + G0b = nx.from_numpy(G0b) Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', verbose=True) gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=False) nx.assert_same_dtype_device(C1b, Gb) @@ -554,6 +562,7 @@ def test_fgw(nx): p = ot.unif(n_samples) q = ot.unif(n_samples) + G0 = p[:, None] * q[None, :] C1 = ot.dist(xs, xs) C2 = ot.dist(xt, xt) @@ -569,9 +578,10 @@ def test_fgw(nx): C2b = nx.from_numpy(C2) pb = nx.from_numpy(p) qb = nx.from_numpy(q) + G0b = nx.from_numpy(G0) - G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5, log=True) - Gb, logb = ot.gromov.fused_gromov_wasserstein(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, log=True) + G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5, G0=G0, log=True) + Gb, logb = ot.gromov.fused_gromov_wasserstein(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, G0=G0b, log=True) Gb = nx.to_numpy(Gb) # check constraints @@ -586,8 +596,8 @@ def test_fgw(nx): np.testing.assert_allclose( Gb, np.flipud(Id), atol=1e-04) # cf convergence gromov - fgw, log = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, 'square_loss', alpha=0.5, log=True) - fgwb, logb = ot.gromov.fused_gromov_wasserstein2(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, log=True) + fgw, log = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, 'square_loss', G0=None, alpha=0.5, log=True) + fgwb, logb = ot.gromov.fused_gromov_wasserstein2(Mb, C1b, C2b, pb, qb, 'square_loss', G0=G0b, alpha=0.5, log=True) fgwb = nx.to_numpy(fgwb) G = log['T'] @@ -698,3 +708,523 @@ def test_fgw_barycenter(nx): Xb, Cb = nx.to_numpy(Xb), nx.to_numpy(Cb) np.testing.assert_allclose(Cb.shape, (n_samples, n_samples)) np.testing.assert_allclose(Xb.shape, (n_samples, ys.shape[1])) + + +def test_gromov_wasserstein_linear_unmixing(nx): + n = 10 + + X1, y1 = ot.datasets.make_data_classif('3gauss', n, random_state=42) + X2, y2 = ot.datasets.make_data_classif('3gauss2', n, random_state=42) + + C1 = ot.dist(X1) + C2 = ot.dist(X2) + Cdict = np.stack([C1, C2]) + p = ot.unif(n) + + C1b = nx.from_numpy(C1) + C2b = nx.from_numpy(C2) + Cdictb = nx.from_numpy(Cdict) + pb = nx.from_numpy(p) + tol = 10**(-5) + # Tests without regularization + reg = 0. 
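+ # gromov_wasserstein_linear_unmixing unmixes a structure matrix over the dictionary atoms in Cdict and
+ # returns, in order: the unmixing weights, the embedded structure built from the atoms, the OT plan
+ # between the input and its embedding, and the reconstruction error; reg > 0 is meant to promote sparser
+ # unmixing weights. The calls below check that results with and without the backend match and that each
+ # input is reconstructed mainly from its own atom.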
+ unmixing1, C1_emb, OT, reconstruction1 = ot.gromov.gromov_wasserstein_linear_unmixing( + C1, Cdict, reg=reg, p=p, q=p, + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + + unmixing1b, C1b_emb, OTb, reconstruction1b = ot.gromov.gromov_wasserstein_linear_unmixing( + C1b, Cdictb, reg=reg, p=None, q=None, + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + + unmixing2, C2_emb, OT, reconstruction2 = ot.gromov.gromov_wasserstein_linear_unmixing( + C2, Cdict, reg=reg, p=None, q=None, + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + + unmixing2b, C2b_emb, OTb, reconstruction2b = ot.gromov.gromov_wasserstein_linear_unmixing( + C2b, Cdictb, reg=reg, p=pb, q=pb, + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + + np.testing.assert_allclose(unmixing1, nx.to_numpy(unmixing1b), atol=1e-06) + np.testing.assert_allclose(unmixing1, [1., 0.], atol=1e-01) + np.testing.assert_allclose(unmixing2, nx.to_numpy(unmixing2b), atol=1e-06) + np.testing.assert_allclose(unmixing2, [0., 1.], atol=1e-01) + np.testing.assert_allclose(C1_emb, nx.to_numpy(C1b_emb), atol=1e-06) + np.testing.assert_allclose(C2_emb, nx.to_numpy(C2b_emb), atol=1e-06) + np.testing.assert_allclose(reconstruction1, reconstruction1b, atol=1e-06) + np.testing.assert_allclose(reconstruction2, reconstruction2b, atol=1e-06) + np.testing.assert_allclose(C1b_emb.shape, (n, n)) + np.testing.assert_allclose(C2b_emb.shape, (n, n)) + + # Tests with regularization + + reg = 0.001 + unmixing1, C1_emb, OT, reconstruction1 = ot.gromov.gromov_wasserstein_linear_unmixing( + C1, Cdict, reg=reg, p=p, q=p, + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + + unmixing1b, C1b_emb, OTb, reconstruction1b = ot.gromov.gromov_wasserstein_linear_unmixing( + C1b, Cdictb, reg=reg, p=None, q=None, + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + + unmixing2, C2_emb, OT, reconstruction2 = ot.gromov.gromov_wasserstein_linear_unmixing( + C2, Cdict, reg=reg, p=None, q=None, + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + + unmixing2b, C2b_emb, OTb, reconstruction2b = ot.gromov.gromov_wasserstein_linear_unmixing( + C2b, Cdictb, reg=reg, p=pb, q=pb, + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + + np.testing.assert_allclose(unmixing1, nx.to_numpy(unmixing1b), atol=1e-06) + np.testing.assert_allclose(unmixing1, [1., 0.], atol=1e-01) + np.testing.assert_allclose(unmixing2, nx.to_numpy(unmixing2b), atol=1e-06) + np.testing.assert_allclose(unmixing2, [0., 1.], atol=1e-01) + np.testing.assert_allclose(C1_emb, nx.to_numpy(C1b_emb), atol=1e-06) + np.testing.assert_allclose(C2_emb, nx.to_numpy(C2b_emb), atol=1e-06) + np.testing.assert_allclose(reconstruction1, reconstruction1b, atol=1e-06) + np.testing.assert_allclose(reconstruction2, reconstruction2b, atol=1e-06) + np.testing.assert_allclose(C1b_emb.shape, (n, n)) + np.testing.assert_allclose(C2b_emb.shape, (n, n)) + + +def test_gromov_wasserstein_dictionary_learning(nx): + + # create dataset composed from 2 structures which are repeated 5 times + shape = 10 + n_samples = 2 + n_atoms = 2 + projection = 'nonnegative_symmetric' + X1, y1 = ot.datasets.make_data_classif('3gauss', shape, random_state=42) + X2, y2 = ot.datasets.make_data_classif('3gauss2', shape, random_state=42) + C1 = ot.dist(X1) + C2 = ot.dist(X2) + Cs = [C1.copy() for _ in range(n_samples // 2)] + [C2.copy() for _ in range(n_samples // 2)] + ps = [ot.unif(shape) for 
_ in range(n_samples)] + q = ot.unif(shape) + + # Provide initialization for the graph dictionary of shape (n_atoms, shape, shape) + # following the same procedure than implemented in gromov_wasserstein_dictionary_learning. + dataset_means = [C.mean() for C in Cs] + np.random.seed(0) + Cdict_init = np.random.normal(loc=np.mean(dataset_means), scale=np.std(dataset_means), size=(n_atoms, shape, shape)) + if projection == 'nonnegative_symmetric': + Cdict_init = 0.5 * (Cdict_init + Cdict_init.transpose((0, 2, 1))) + Cdict_init[Cdict_init < 0.] = 0. + Csb = [nx.from_numpy(C) for C in Cs] + psb = [nx.from_numpy(p) for p in ps] + qb = nx.from_numpy(q) + Cdict_initb = nx.from_numpy(Cdict_init) + + # Test: compare reconstruction error using initial dictionary and dictionary learned using this initialization + # > Compute initial reconstruction of samples on this random dictionary without backend + use_adam_optimizer = True + verbose = False + tol = 10**(-5) + epochs = 1 + + initial_total_reconstruction = 0 + for i in range(n_samples): + _, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing( + Cs[i], Cdict_init, p=ps[i], q=q, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + initial_total_reconstruction += reconstruction + + # > Learn the dictionary using this init + Cdict, log = ot.gromov.gromov_wasserstein_dictionary_learning( + Cs, D=n_atoms, nt=shape, ps=ps, q=q, Cdict_init=Cdict_init, + epochs=epochs, batch_size=2 * n_samples, learning_rate=1., reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200, + projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose + ) + # > Compute reconstruction of samples on learned dictionary without backend + total_reconstruction = 0 + for i in range(n_samples): + _, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing( + Cs[i], Cdict, p=None, q=None, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + total_reconstruction += reconstruction + + np.testing.assert_array_less(total_reconstruction, initial_total_reconstruction) + + # Test: Perform same experiments after going through backend + + Cdictb, log = ot.gromov.gromov_wasserstein_dictionary_learning( + Csb, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=Cdict_initb, + epochs=epochs, batch_size=n_samples, learning_rate=1., reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200, + projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose + ) + # Compute reconstruction of samples on learned dictionary + total_reconstruction_b = 0 + for i in range(n_samples): + _, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing( + Csb[i], Cdictb, p=psb[i], q=qb, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + total_reconstruction_b += reconstruction + + np.testing.assert_array_less(total_reconstruction_b, initial_total_reconstruction) + np.testing.assert_allclose(total_reconstruction_b, total_reconstruction, atol=1e-05) + np.testing.assert_allclose(total_reconstruction_b, total_reconstruction, atol=1e-05) + np.testing.assert_allclose(Cdict, nx.to_numpy(Cdictb), atol=1e-03) + + # Test: Perform same comparison without providing the initial dictionary being an optional input + # knowing than the initialization scheme is the same than implemented to set the benchmarked initialization. 
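+ # When Cdict_init is None the dictionary is initialized internally from np.random with the same scheme
+ # as the manual Cdict_init above, so re-seeding with 0 should reproduce the benchmarked initialization
+ # and yield matching reconstruction errors.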
+ np.random.seed(0) + Cdict_bis, log = ot.gromov.gromov_wasserstein_dictionary_learning( + Cs, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=None, + epochs=epochs, batch_size=n_samples, learning_rate=1., reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200, + projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose + ) + # > Compute reconstruction of samples on learned dictionary + total_reconstruction_bis = 0 + for i in range(n_samples): + _, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing( + Cs[i], Cdict_bis, p=ps[i], q=q, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + total_reconstruction_bis += reconstruction + + np.testing.assert_allclose(total_reconstruction_bis, total_reconstruction, atol=1e-05) + + # Test: Same after going through backend + np.random.seed(0) + Cdictb_bis, log = ot.gromov.gromov_wasserstein_dictionary_learning( + Csb, D=n_atoms, nt=shape, ps=psb, q=qb, Cdict_init=None, + epochs=epochs, batch_size=n_samples, learning_rate=1., reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200, + projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose + ) + # > Compute reconstruction of samples on learned dictionary + total_reconstruction_b_bis = 0 + for i in range(n_samples): + _, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing( + Csb[i], Cdictb_bis, p=None, q=None, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + total_reconstruction_b_bis += reconstruction + + np.testing.assert_allclose(total_reconstruction_b_bis, total_reconstruction_b, atol=1e-05) + np.testing.assert_allclose(Cdict_bis, nx.to_numpy(Cdictb_bis), atol=1e-03) + + # Test: Perform same comparison without providing the initial dictionary being an optional input + # and testing other optimization settings untested until now. + # We pass previously estimated dictionaries to speed up the process. 
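+ # These settings exercise the remaining code paths: plain gradient updates instead of Adam, plus the
+ # use_log and verbose branches.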
+ use_adam_optimizer = False + verbose = True + use_log = True + + np.random.seed(0) + Cdict_bis2, log = ot.gromov.gromov_wasserstein_dictionary_learning( + Cs, D=n_atoms, nt=shape, ps=ps, q=q, Cdict_init=Cdict, + epochs=epochs, batch_size=n_samples, learning_rate=10., reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200, + projection=projection, use_log=use_log, use_adam_optimizer=use_adam_optimizer, verbose=verbose + ) + # > Compute reconstruction of samples on learned dictionary + total_reconstruction_bis2 = 0 + for i in range(n_samples): + _, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing( + Cs[i], Cdict_bis2, p=ps[i], q=q, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + total_reconstruction_bis2 += reconstruction + + np.testing.assert_array_less(total_reconstruction_bis2, total_reconstruction) + + # Test: Same after going through backend + np.random.seed(0) + Cdictb_bis2, log = ot.gromov.gromov_wasserstein_dictionary_learning( + Csb, D=n_atoms, nt=shape, ps=psb, q=qb, Cdict_init=Cdictb, + epochs=epochs, batch_size=n_samples, learning_rate=10., reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200, + projection=projection, use_log=use_log, use_adam_optimizer=use_adam_optimizer, verbose=verbose + ) + # > Compute reconstruction of samples on learned dictionary + total_reconstruction_b_bis2 = 0 + for i in range(n_samples): + _, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing( + Csb[i], Cdictb_bis2, p=psb[i], q=qb, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + total_reconstruction_b_bis2 += reconstruction + + np.testing.assert_allclose(total_reconstruction_b_bis2, total_reconstruction_bis2, atol=1e-05) + + +def test_fused_gromov_wasserstein_linear_unmixing(nx): + + n = 10 + X1, y1 = ot.datasets.make_data_classif('3gauss', n, random_state=42) + X2, y2 = ot.datasets.make_data_classif('3gauss2', n, random_state=42) + F, y = ot.datasets.make_data_classif('3gauss', n, random_state=42) + + C1 = ot.dist(X1) + C2 = ot.dist(X2) + Cdict = np.stack([C1, C2]) + Ydict = np.stack([F, F]) + p = ot.unif(n) + + C1b = nx.from_numpy(C1) + C2b = nx.from_numpy(C2) + Fb = nx.from_numpy(F) + Cdictb = nx.from_numpy(Cdict) + Ydictb = nx.from_numpy(Ydict) + pb = nx.from_numpy(p) + # Tests without regularization + reg = 0. 
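+ # The fused variant additionally takes node features F and feature atoms Ydict; alpha balances the
+ # structure and feature terms, and the embedded features are returned alongside the embedded structure.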
+ + unmixing1, C1_emb, Y1_emb, OT, reconstruction1 = ot.gromov.fused_gromov_wasserstein_linear_unmixing( + C1, F, Cdict, Ydict, p=p, q=p, alpha=0.5, reg=reg, + tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=20, max_iter_inner=200 + ) + + unmixing1b, C1b_emb, Y1b_emb, OTb, reconstruction1b = ot.gromov.fused_gromov_wasserstein_linear_unmixing( + C1b, Fb, Cdictb, Ydictb, p=None, q=None, alpha=0.5, reg=reg, + tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=20, max_iter_inner=200 + ) + + unmixing2, C2_emb, Y2_emb, OT, reconstruction2 = ot.gromov.fused_gromov_wasserstein_linear_unmixing( + C2, F, Cdict, Ydict, p=None, q=None, alpha=0.5, reg=reg, + tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=20, max_iter_inner=200 + ) + + unmixing2b, C2b_emb, Y2b_emb, OTb, reconstruction2b = ot.gromov.fused_gromov_wasserstein_linear_unmixing( + C2b, Fb, Cdictb, Ydictb, p=pb, q=pb, alpha=0.5, reg=reg, + tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=20, max_iter_inner=200 + ) + + np.testing.assert_allclose(unmixing1, nx.to_numpy(unmixing1b), atol=1e-06) + np.testing.assert_allclose(unmixing1, [1., 0.], atol=1e-01) + np.testing.assert_allclose(unmixing2, nx.to_numpy(unmixing2b), atol=1e-06) + np.testing.assert_allclose(unmixing2, [0., 1.], atol=1e-01) + np.testing.assert_allclose(C1_emb, nx.to_numpy(C1b_emb), atol=1e-03) + np.testing.assert_allclose(C2_emb, nx.to_numpy(C2b_emb), atol=1e-03) + np.testing.assert_allclose(Y1_emb, nx.to_numpy(Y1b_emb), atol=1e-03) + np.testing.assert_allclose(Y2_emb, nx.to_numpy(Y2b_emb), atol=1e-03) + np.testing.assert_allclose(reconstruction1, reconstruction1b, atol=1e-06) + np.testing.assert_allclose(reconstruction2, reconstruction2b, atol=1e-06) + np.testing.assert_allclose(C1b_emb.shape, (n, n)) + np.testing.assert_allclose(C2b_emb.shape, (n, n)) + + # Tests with regularization + reg = 0.001 + + unmixing1, C1_emb, Y1_emb, OT, reconstruction1 = ot.gromov.fused_gromov_wasserstein_linear_unmixing( + C1, F, Cdict, Ydict, p=p, q=p, alpha=0.5, reg=reg, + tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=20, max_iter_inner=200 + ) + + unmixing1b, C1b_emb, Y1b_emb, OTb, reconstruction1b = ot.gromov.fused_gromov_wasserstein_linear_unmixing( + C1b, Fb, Cdictb, Ydictb, p=None, q=None, alpha=0.5, reg=reg, + tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=20, max_iter_inner=200 + ) + + unmixing2, C2_emb, Y2_emb, OT, reconstruction2 = ot.gromov.fused_gromov_wasserstein_linear_unmixing( + C2, F, Cdict, Ydict, p=None, q=None, alpha=0.5, reg=reg, + tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=20, max_iter_inner=200 + ) + + unmixing2b, C2b_emb, Y2b_emb, OTb, reconstruction2b = ot.gromov.fused_gromov_wasserstein_linear_unmixing( + C2b, Fb, Cdictb, Ydictb, p=pb, q=pb, alpha=0.5, reg=reg, + tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=20, max_iter_inner=200 + ) + + np.testing.assert_allclose(unmixing1, nx.to_numpy(unmixing1b), atol=1e-06) + np.testing.assert_allclose(unmixing1, [1., 0.], atol=1e-01) + np.testing.assert_allclose(unmixing2, nx.to_numpy(unmixing2b), atol=1e-06) + np.testing.assert_allclose(unmixing2, [0., 1.], atol=1e-01) + np.testing.assert_allclose(C1_emb, nx.to_numpy(C1b_emb), atol=1e-03) + np.testing.assert_allclose(C2_emb, nx.to_numpy(C2b_emb), atol=1e-03) + np.testing.assert_allclose(Y1_emb, nx.to_numpy(Y1b_emb), atol=1e-03) + np.testing.assert_allclose(Y2_emb, nx.to_numpy(Y2b_emb), atol=1e-03) + np.testing.assert_allclose(reconstruction1, reconstruction1b, atol=1e-06) + np.testing.assert_allclose(reconstruction2, 
reconstruction2b, atol=1e-06) + np.testing.assert_allclose(C1b_emb.shape, (n, n)) + np.testing.assert_allclose(C2b_emb.shape, (n, n)) + + +def test_fused_gromov_wasserstein_dictionary_learning(nx): + + # create dataset composed from 2 structures which are repeated 5 times + shape = 10 + n_samples = 2 + n_atoms = 2 + projection = 'nonnegative_symmetric' + X1, y1 = ot.datasets.make_data_classif('3gauss', shape, random_state=42) + X2, y2 = ot.datasets.make_data_classif('3gauss2', shape, random_state=42) + F, y = ot.datasets.make_data_classif('3gauss', shape, random_state=42) + + C1 = ot.dist(X1) + C2 = ot.dist(X2) + Cs = [C1.copy() for _ in range(n_samples // 2)] + [C2.copy() for _ in range(n_samples // 2)] + Ys = [F.copy() for _ in range(n_samples)] + ps = [ot.unif(shape) for _ in range(n_samples)] + q = ot.unif(shape) + + # Provide initialization for the graph dictionary of shape (n_atoms, shape, shape) + # following the same procedure than implemented in gromov_wasserstein_dictionary_learning. + dataset_structure_means = [C.mean() for C in Cs] + np.random.seed(0) + Cdict_init = np.random.normal(loc=np.mean(dataset_structure_means), scale=np.std(dataset_structure_means), size=(n_atoms, shape, shape)) + if projection == 'nonnegative_symmetric': + Cdict_init = 0.5 * (Cdict_init + Cdict_init.transpose((0, 2, 1))) + Cdict_init[Cdict_init < 0.] = 0. + dataset_feature_means = np.stack([Y.mean(axis=0) for Y in Ys]) + Ydict_init = np.random.normal(loc=dataset_feature_means.mean(axis=0), scale=dataset_feature_means.std(axis=0), size=(n_atoms, shape, 2)) + + Csb = [nx.from_numpy(C) for C in Cs] + Ysb = [nx.from_numpy(Y) for Y in Ys] + psb = [nx.from_numpy(p) for p in ps] + qb = nx.from_numpy(q) + Cdict_initb = nx.from_numpy(Cdict_init) + Ydict_initb = nx.from_numpy(Ydict_init) + + # Test: Compute initial reconstruction of samples on this random dictionary + alpha = 0.5 + use_adam_optimizer = True + verbose = False + tol = 1e-05 + epochs = 1 + + initial_total_reconstruction = 0 + for i in range(n_samples): + _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing( + Cs[i], Ys[i], Cdict_init, Ydict_init, p=ps[i], q=q, + alpha=alpha, reg=0., tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + initial_total_reconstruction += reconstruction + + # > Learn a dictionary using this given initialization and check that the reconstruction loss + # on the learned dictionary is lower than the one using its initialization. 
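+ # fused_gromov_wasserstein_dictionary_learning updates structure atoms (Cdict) and feature atoms (Ydict)
+ # jointly, with separate step sizes learning_rate_C and learning_rate_Y.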
+ Cdict, Ydict, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning( + Cs, Ys, D=n_atoms, nt=shape, ps=ps, q=q, Cdict_init=Cdict_init, Ydict_init=Ydict_init, + epochs=epochs, batch_size=n_samples, learning_rate_C=1., learning_rate_Y=1., alpha=alpha, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200, + projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose + ) + # > Compute reconstruction of samples on learned dictionary + total_reconstruction = 0 + for i in range(n_samples): + _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing( + Cs[i], Ys[i], Cdict, Ydict, p=None, q=None, alpha=alpha, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + total_reconstruction += reconstruction + # Compare both + np.testing.assert_array_less(total_reconstruction, initial_total_reconstruction) + + # Test: Perform same experiments after going through backend + + Cdictb, Ydictb, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning( + Csb, Ysb, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=Cdict_initb, Ydict_init=Ydict_initb, + epochs=epochs, batch_size=2 * n_samples, learning_rate_C=1., learning_rate_Y=1., alpha=alpha, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200, + projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose + ) + # > Compute reconstruction of samples on learned dictionary + total_reconstruction_b = 0 + for i in range(n_samples): + _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing( + Csb[i], Ysb[i], Cdictb, Ydictb, p=psb[i], q=qb, alpha=alpha, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + total_reconstruction_b += reconstruction + + np.testing.assert_array_less(total_reconstruction_b, initial_total_reconstruction) + np.testing.assert_allclose(total_reconstruction_b, total_reconstruction, atol=1e-05) + np.testing.assert_allclose(Cdict, nx.to_numpy(Cdictb), atol=1e-03) + np.testing.assert_allclose(Ydict, nx.to_numpy(Ydictb), atol=1e-03) + + # Test: Perform similar experiment without providing the initial dictionary being an optional input + np.random.seed(0) + Cdict_bis, Ydict_bis, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning( + Cs, Ys, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=None, Ydict_init=None, + epochs=epochs, batch_size=n_samples, learning_rate_C=1., learning_rate_Y=1., alpha=alpha, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200, + projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose + ) + # > Compute reconstruction of samples on learned dictionary + total_reconstruction_bis = 0 + for i in range(n_samples): + _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing( + Cs[i], Ys[i], Cdict_bis, Ydict_bis, p=ps[i], q=q, alpha=alpha, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + total_reconstruction_bis += reconstruction + + np.testing.assert_allclose(total_reconstruction_bis, total_reconstruction, atol=1e-05) + + # > Same after going through backend + np.random.seed(0) + Cdictb_bis, Ydictb_bis, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning( + Csb, Ysb, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=None, Ydict_init=None, + epochs=epochs, batch_size=n_samples, learning_rate_C=1., learning_rate_Y=1., alpha=alpha, reg=0., + tol_outer=tol, 
tol_inner=tol, max_iter_outer=20, max_iter_inner=200, + projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose + ) + + # > Compute reconstruction of samples on learned dictionary + total_reconstruction_b_bis = 0 + for i in range(n_samples): + _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing( + Csb[i], Ysb[i], Cdictb_bis, Ydictb_bis, p=psb[i], q=qb, alpha=alpha, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + total_reconstruction_b_bis += reconstruction + np.testing.assert_allclose(total_reconstruction_b_bis, total_reconstruction_b, atol=1e-05) + + # Test: without using adam optimizer, with log and verbose set to True + use_adam_optimizer = False + verbose = True + use_log = True + + # > Experiment providing previously estimated dictionary to speed up the test compared to providing initial random init. + np.random.seed(0) + Cdict_bis2, Ydict_bis2, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning( + Cs, Ys, D=n_atoms, nt=shape, ps=ps, q=q, Cdict_init=Cdict, Ydict_init=Ydict, + epochs=epochs, batch_size=n_samples, learning_rate_C=10., learning_rate_Y=10., alpha=alpha, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200, + projection=projection, use_log=use_log, use_adam_optimizer=use_adam_optimizer, verbose=verbose + ) + # > Compute reconstruction of samples on learned dictionary + total_reconstruction_bis2 = 0 + for i in range(n_samples): + _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing( + Cs[i], Ys[i], Cdict_bis2, Ydict_bis2, p=ps[i], q=q, alpha=alpha, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + total_reconstruction_bis2 += reconstruction + + np.testing.assert_array_less(total_reconstruction_bis2, total_reconstruction) + + # > Same after going through backend + np.random.seed(0) + Cdictb_bis2, Ydictb_bis2, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning( + Csb, Ysb, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=Cdictb, Ydict_init=Ydictb, + epochs=epochs, batch_size=n_samples, learning_rate_C=10., learning_rate_Y=10., alpha=alpha, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200, + projection=projection, use_log=use_log, use_adam_optimizer=use_adam_optimizer, verbose=verbose + ) + + # > Compute reconstruction of samples on learned dictionary + total_reconstruction_b_bis2 = 0 + for i in range(n_samples): + _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing( + Csb[i], Ysb[i], Cdictb_bis2, Ydictb_bis2, p=None, q=None, alpha=alpha, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + total_reconstruction_b_bis2 += reconstruction + + # > Compare results with/without backend + np.testing.assert_allclose(total_reconstruction_bis2, total_reconstruction_b_bis2, atol=1e-05) -- cgit v1.2.3 From 9412f0ad1c0003e659b7d779bf8b6728e0e5e60f Mon Sep 17 00:00:00 2001 From: Nathan Cassereau <84033440+ncassereau-idris@users.noreply.github.com> Date: Wed, 2 Mar 2022 11:35:47 +0100 Subject: [MRG] Gromov_Wasserstein2 not performing backward properly on GPU (#352) * Resolves gromov wasserstein backward bug * release file updated --- RELEASES.md | 3 +++ ot/gromov.py | 12 +++++++---- test/test_gromov.py | 60 +++++++++++++++++++++++++++++++---------------------- 3 files changed, 46 insertions(+), 29 deletions(-) (limited to 'test/test_gromov.py') diff --git a/RELEASES.md b/RELEASES.md index 
c1068f3..18562e7 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -18,6 +18,9 @@ - Fix bug in instantiating an `autograd` function `ValFunction` (Issue #337, PR #338) - Fix POT ABI compatibility with old and new numpy (Issue #346, PR #349) +- Fix bug where gromov_wasserstein2 does not perform backpropagation with CUDA + tensors (Issue #351, PR #352) + ## 0.8.1.0 *December 2021* diff --git a/ot/gromov.py b/ot/gromov.py index f5a1f91..c5a82d1 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -546,8 +546,10 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo= gw = log_gw['gw_dist'] if loss_fun == 'square_loss': - gC1 = nx.from_numpy(2 * C1 * (p[:, None] * p[None, :]) - 2 * T.dot(C2).dot(T.T)) - gC2 = nx.from_numpy(2 * C2 * (q[:, None] * q[None, :]) - 2 * T.T.dot(C1).dot(T)) + gC1 = 2 * C1 * (p[:, None] * p[None, :]) - 2 * T.dot(C2).dot(T.T) + gC2 = 2 * C2 * (q[:, None] * q[None, :]) - 2 * T.T.dot(C1).dot(T) + gC1 = nx.from_numpy(gC1, type_as=C10) + gC2 = nx.from_numpy(gC2, type_as=C10) gw = nx.set_gradients(gw, (p0, q0, C10, C20), (log_gw['u'], log_gw['v'], gC1, gC2)) @@ -786,8 +788,10 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5 log_fgw['T'] = T0 if loss_fun == 'square_loss': - gC1 = nx.from_numpy(2 * C1 * (p[:, None] * p[None, :]) - 2 * T.dot(C2).dot(T.T)) - gC2 = nx.from_numpy(2 * C2 * (q[:, None] * q[None, :]) - 2 * T.T.dot(C1).dot(T)) + gC1 = 2 * C1 * (p[:, None] * p[None, :]) - 2 * T.dot(C2).dot(T.T) + gC2 = 2 * C2 * (q[:, None] * q[None, :]) - 2 * T.T.dot(C1).dot(T) + gC1 = nx.from_numpy(gC1, type_as=C10) + gC2 = nx.from_numpy(gC2, type_as=C10) fgw_dist = nx.set_gradients(fgw_dist, (p0, q0, C10, C20, M0), (log_fgw['u'], log_fgw['v'], alpha * gC1, alpha * gC2, (1 - alpha) * T0)) diff --git a/test/test_gromov.py b/test/test_gromov.py index 329f99c..0dcf2da 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -181,19 +181,24 @@ def test_gromov2_gradients(): if torch: - p1 = torch.tensor(p, requires_grad=True) - q1 = torch.tensor(q, requires_grad=True) - C11 = torch.tensor(C1, requires_grad=True) - C12 = torch.tensor(C2, requires_grad=True) + devices = [torch.device("cpu")] + if torch.cuda.is_available(): + devices.append(torch.device("cuda")) + for device in devices: + p1 = torch.tensor(p, requires_grad=True, device=device) + q1 = torch.tensor(q, requires_grad=True, device=device) + C11 = torch.tensor(C1, requires_grad=True, device=device) + C12 = torch.tensor(C2, requires_grad=True, device=device) - val = ot.gromov_wasserstein2(C11, C12, p1, q1) + val = ot.gromov_wasserstein2(C11, C12, p1, q1) - val.backward() + val.backward() - assert q1.shape == q1.grad.shape - assert p1.shape == p1.grad.shape - assert C11.shape == C11.grad.shape - assert C12.shape == C12.grad.shape + assert val.device == p1.device + assert q1.shape == q1.grad.shape + assert p1.shape == p1.grad.shape + assert C11.shape == C11.grad.shape + assert C12.shape == C12.grad.shape @pytest.skip_backend("jax", reason="test very slow with jax backend") @@ -636,21 +641,26 @@ def test_fgw2_gradients(): if torch: - p1 = torch.tensor(p, requires_grad=True) - q1 = torch.tensor(q, requires_grad=True) - C11 = torch.tensor(C1, requires_grad=True) - C12 = torch.tensor(C2, requires_grad=True) - M1 = torch.tensor(M, requires_grad=True) - - val = ot.fused_gromov_wasserstein2(M1, C11, C12, p1, q1) - - val.backward() - - assert q1.shape == q1.grad.shape - assert p1.shape == p1.grad.shape - assert C11.shape == C11.grad.shape - assert C12.shape == C12.grad.shape - assert 
M1.shape == M1.grad.shape + devices = [torch.device("cpu")] + if torch.cuda.is_available(): + devices.append(torch.device("cuda")) + for device in devices: + p1 = torch.tensor(p, requires_grad=True, device=device) + q1 = torch.tensor(q, requires_grad=True, device=device) + C11 = torch.tensor(C1, requires_grad=True, device=device) + C12 = torch.tensor(C2, requires_grad=True, device=device) + M1 = torch.tensor(M, requires_grad=True, device=device) + + val = ot.fused_gromov_wasserstein2(M1, C11, C12, p1, q1) + + val.backward() + + assert val.device == p1.device + assert q1.shape == q1.grad.shape + assert p1.shape == p1.grad.shape + assert C11.shape == C11.grad.shape + assert C12.shape == C12.grad.shape + assert M1.shape == M1.grad.shape def test_fgw_barycenter(nx): -- cgit v1.2.3 From 767171593f2a98a26b9a39bf110a45085e3b982e Mon Sep 17 00:00:00 2001 From: Nathan Cassereau <84033440+ncassereau-idris@users.noreply.github.com> Date: Thu, 24 Mar 2022 10:53:47 +0100 Subject: [MRG] Domain adaptation and unbalanced solvers with backend support (#343) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * First draft * Add matrix inverse and square root to backend * Eigen decomposition for older versions of pytorch (1.8.1 and older) * Corrected eigen decomposition for pytorch 1.8.1 and older * Spectral theorem is a thing * Optimization * small optimization * More functions converted * pep8 * remove a warning and prepare torch meshgrid for future torch release (which will change default indexing) * dots and pep8 * Meshgrid corrected for older version and prepared for future versions changes * New backend functions * Base transport * LinearTransport * All transport classes + pep8 * PR added to release file * Jcpot barycenter test * unbalanced with backend * pep8 * bug solve * test of domain adaptation with backends * solve bug for tic toc & macos * solving scipy deprecation warning * solving scipy deprecation warning attempt2 * solving scipy deprecation warning attempt3 * A warning is triggered when a float->int conversion is detected * bug solve * docs * release file updated * Better handling of float->int conversion in EMD * Corrected test for is_floating_point * docs * release file updated * cupy does not allow implicit cast * fromnumpy * added test * test da tf jax * test unbalanced with no provided histogram * using type_as argument in unif function correctly * pep8 * transport plan cast in emd changed behaviour, now trying to cast as histogram's dtype, defaulting to cost matrix Co-authored-by: Rémi Flamary --- RELEASES.md | 2 + ot/backend.py | 304 ++++++++++++++++++++++++++++++++++---- ot/bregman.py | 17 +-- ot/da.py | 382 +++++++++++++++++++++++++++--------------------- ot/lp/__init__.py | 83 ++++++++--- ot/optim.py | 11 +- ot/unbalanced.py | 302 +++++++++++++++++++------------------- ot/utils.py | 26 +++- test/test_1d_solver.py | 28 +--- test/test_backend.py | 66 ++++++++- test/test_bregman.py | 81 +++------- test/test_da.py | 307 +++++++++++++++++++------------------- test/test_gromov.py | 147 +++++++------------ test/test_optim.py | 17 +-- test/test_ot.py | 19 +-- test/test_sliced.py | 32 +--- test/test_unbalanced.py | 157 ++++++++++++-------- test/test_weak.py | 4 +- 18 files changed, 1160 insertions(+), 825 deletions(-) (limited to 'test/test_gromov.py') diff --git a/RELEASES.md b/RELEASES.md index 0f1f231..86b401a 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -11,6 +11,7 @@ of the regularization parameter (PR #336). 
- Backend implementation for `ot.lp.free_support_barycenter` (PR #340). - Add weak OT solver + example (PR #341). +- Add backend support for Domain Adaptation and Unbalanced solvers (PR #343). - Add (F)GW linear dictionary learning solvers + example (PR #319) - Add links to related PR and Issues in the doc release page (PR #350) @@ -19,6 +20,7 @@ - Fix bug in instantiating an `autograd` function `ValFunction` (Issue #337, PR #338) - Fix POT ABI compatibility with old and new numpy (Issue #346, PR #349) +- Warning when feeding integer cost matrix to EMD solver resulting in an integer transport plan (Issue #345, PR #343) - Fix bug where gromov_wasserstein2 does not perform backpropagation with CUDA tensors (Issue #351, PR #352) diff --git a/ot/backend.py b/ot/backend.py index 6e0bc3d..361ffba 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -87,7 +87,9 @@ Performance # License: MIT License import numpy as np -import scipy.special as scipy +import scipy +import scipy.linalg +import scipy.special as special from scipy.sparse import issparse, coo_matrix, csr_matrix import warnings import time @@ -102,7 +104,7 @@ except ImportError: try: import jax import jax.numpy as jnp - import jax.scipy.special as jscipy + import jax.scipy.special as jspecial from jax.lib import xla_bridge jax_type = jax.numpy.ndarray except ImportError: @@ -202,13 +204,29 @@ class Backend(): def __str__(self): return self.__name__ - # convert to numpy - def to_numpy(self, a): + # convert batch of tensors to numpy + def to_numpy(self, *arrays): + """Returns the numpy version of tensors""" + if len(arrays) == 1: + return self._to_numpy(arrays[0]) + else: + return [self._to_numpy(array) for array in arrays] + + # convert a tensor to numpy + def _to_numpy(self, a): """Returns the numpy version of a tensor""" raise NotImplementedError() - # convert from numpy - def from_numpy(self, a, type_as=None): + # convert batch of arrays from numpy + def from_numpy(self, *arrays, type_as=None): + """Creates tensors cloning a numpy array, with the given precision (defaulting to input's precision) and the given device (in case of GPUs)""" + if len(arrays) == 1: + return self._from_numpy(arrays[0], type_as=type_as) + else: + return [self._from_numpy(array, type_as=type_as) for array in arrays] + + # convert an array from numpy + def _from_numpy(self, a, type_as=None): """Creates a tensor cloning a numpy array, with the given precision (defaulting to input's precision) and the given device (in case of GPUs)""" raise NotImplementedError() @@ -536,6 +554,16 @@ class Backend(): """ raise NotImplementedError() + def argmin(self, a, axis=None): + r""" + Returns the indices of the minimum values of a tensor along given dimensions. + + This function follows the api from :any:`numpy.argmin` + + See: https://numpy.org/doc/stable/reference/generated/numpy.argmin.html + """ + raise NotImplementedError() + def mean(self, a, axis=None): r""" Computes the arithmetic mean of a tensor along given dimensions. @@ -786,6 +814,72 @@ class Backend(): """ raise NotImplementedError() + def solve(self, a, b): + r""" + Solves a linear matrix equation, or system of linear scalar equations. + + This function follows the api from :any:`numpy.linalg.solve`. + + See: https://numpy.org/doc/stable/reference/generated/numpy.linalg.solve.html + """ + raise NotImplementedError() + + def trace(self, a): + r""" + Returns the sum along diagonals of the array. + + This function follows the api from :any:`numpy.trace`. 
+ + See: https://numpy.org/doc/stable/reference/generated/numpy.trace.html + """ + raise NotImplementedError() + + def inv(self, a): + r""" + Computes the inverse of a matrix. + + This function follows the api from :any:`scipy.linalg.inv`. + + See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.linalg.inv.html + """ + raise NotImplementedError() + + def sqrtm(self, a): + r""" + Computes the matrix square root. Requires input to be definite positive. + + This function follows the api from :any:`scipy.linalg.sqrtm`. + + See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.linalg.sqrtm.html + """ + raise NotImplementedError() + + def isfinite(self, a): + r""" + Tests element-wise for finiteness (not infinity and not Not a Number). + + This function follows the api from :any:`numpy.isfinite`. + + See: https://numpy.org/doc/stable/reference/generated/numpy.isfinite.html + """ + raise NotImplementedError() + + def array_equal(self, a, b): + r""" + True if two arrays have the same shape and elements, False otherwise. + + This function follows the api from :any:`numpy.array_equal`. + + See: https://numpy.org/doc/stable/reference/generated/numpy.array_equal.html + """ + raise NotImplementedError() + + def is_floating_point(self, a): + r""" + Returns whether or not the input consists of floats + """ + raise NotImplementedError() + class NumpyBackend(Backend): """ @@ -802,10 +896,10 @@ class NumpyBackend(Backend): rng_ = np.random.RandomState() - def to_numpy(self, a): + def _to_numpy(self, a): return a - def from_numpy(self, a, type_as=None): + def _from_numpy(self, a, type_as=None): if type_as is None: return a elif isinstance(a, float): @@ -936,6 +1030,9 @@ class NumpyBackend(Backend): def argmax(self, a, axis=None): return np.argmax(a, axis=axis) + def argmin(self, a, axis=None): + return np.argmin(a, axis=axis) + def mean(self, a, axis=None): return np.mean(a, axis=axis) @@ -955,7 +1052,7 @@ class NumpyBackend(Backend): return np.unique(a) def logsumexp(self, a, axis=None): - return scipy.logsumexp(a, axis=axis) + return special.logsumexp(a, axis=axis) def stack(self, arrays, axis=0): return np.stack(arrays, axis) @@ -1004,8 +1101,11 @@ class NumpyBackend(Backend): else: return a - def where(self, condition, x, y): - return np.where(condition, x, y) + def where(self, condition, x=None, y=None): + if x is None and y is None: + return np.where(condition) + else: + return np.where(condition, x, y) def copy(self, a): return a.copy() @@ -1046,6 +1146,27 @@ class NumpyBackend(Backend): results[key] = (t1 - t0) / n_runs return results + def solve(self, a, b): + return np.linalg.solve(a, b) + + def trace(self, a): + return np.trace(a) + + def inv(self, a): + return scipy.linalg.inv(a) + + def sqrtm(self, a): + return scipy.linalg.sqrtm(a) + + def isfinite(self, a): + return np.isfinite(a) + + def array_equal(self, a, b): + return np.array_equal(a, b) + + def is_floating_point(self, a): + return a.dtype.kind == "f" + class JaxBackend(Backend): """ @@ -1075,13 +1196,15 @@ class JaxBackend(Backend): jax.device_put(jnp.array(1, dtype=jnp.float64), d) ] - def to_numpy(self, a): + def _to_numpy(self, a): return np.array(a) def _change_device(self, a, type_as): return jax.device_put(a, type_as.device_buffer.device()) - def from_numpy(self, a, type_as=None): + def _from_numpy(self, a, type_as=None): + if isinstance(a, float): + a = np.array(a) if type_as is None: return jnp.array(a) else: @@ -1216,6 +1339,9 @@ class JaxBackend(Backend): def argmax(self, a, axis=None): return jnp.argmax(a, 
axis=axis) + def argmin(self, a, axis=None): + return jnp.argmin(a, axis=axis) + def mean(self, a, axis=None): return jnp.mean(a, axis=axis) @@ -1235,7 +1361,7 @@ class JaxBackend(Backend): return jnp.unique(a) def logsumexp(self, a, axis=None): - return jscipy.logsumexp(a, axis=axis) + return jspecial.logsumexp(a, axis=axis) def stack(self, arrays, axis=0): return jnp.stack(arrays, axis) @@ -1293,8 +1419,11 @@ class JaxBackend(Backend): # Currently, JAX does not support sparse matrices return a - def where(self, condition, x, y): - return jnp.where(condition, x, y) + def where(self, condition, x=None, y=None): + if x is None and y is None: + return jnp.where(condition) + else: + return jnp.where(condition, x, y) def copy(self, a): # No need to copy, JAX arrays are immutable @@ -1339,6 +1468,28 @@ class JaxBackend(Backend): results[key] = (t1 - t0) / n_runs return results + def solve(self, a, b): + return jnp.linalg.solve(a, b) + + def trace(self, a): + return jnp.trace(a) + + def inv(self, a): + return jnp.linalg.inv(a) + + def sqrtm(self, a): + L, V = jnp.linalg.eigh(a) + return (V * jnp.sqrt(L)[None, :]) @ V.T + + def isfinite(self, a): + return jnp.isfinite(a) + + def array_equal(self, a, b): + return jnp.array_equal(a, b) + + def is_floating_point(self, a): + return a.dtype.kind == "f" + class TorchBackend(Backend): """ @@ -1384,10 +1535,10 @@ class TorchBackend(Backend): self.ValFunction = ValFunction - def to_numpy(self, a): + def _to_numpy(self, a): return a.cpu().detach().numpy() - def from_numpy(self, a, type_as=None): + def _from_numpy(self, a, type_as=None): if isinstance(a, float): a = np.array(a) if type_as is None: @@ -1564,6 +1715,9 @@ class TorchBackend(Backend): def argmax(self, a, axis=None): return torch.argmax(a, dim=axis) + def argmin(self, a, axis=None): + return torch.argmin(a, dim=axis) + def mean(self, a, axis=None): if axis is not None: return torch.mean(a, dim=axis) @@ -1580,8 +1734,11 @@ class TorchBackend(Backend): return torch.linspace(start, stop, num, dtype=torch.float64) def meshgrid(self, a, b): - X, Y = torch.meshgrid(a, b) - return X.T, Y.T + try: + return torch.meshgrid(a, b, indexing="xy") + except TypeError: + X, Y = torch.meshgrid(a, b) + return X.T, Y.T def diag(self, a, k=0): return torch.diag(a, diagonal=k) @@ -1659,8 +1816,11 @@ class TorchBackend(Backend): else: return a - def where(self, condition, x, y): - return torch.where(condition, x, y) + def where(self, condition, x=None, y=None): + if x is None and y is None: + return torch.where(condition) + else: + return torch.where(condition, x, y) def copy(self, a): return torch.clone(a) @@ -1718,6 +1878,28 @@ class TorchBackend(Backend): torch.cuda.empty_cache() return results + def solve(self, a, b): + return torch.linalg.solve(a, b) + + def trace(self, a): + return torch.trace(a) + + def inv(self, a): + return torch.linalg.inv(a) + + def sqrtm(self, a): + L, V = torch.linalg.eigh(a) + return (V * torch.sqrt(L)[None, :]) @ V.T + + def isfinite(self, a): + return torch.isfinite(a) + + def array_equal(self, a, b): + return torch.equal(a, b) + + def is_floating_point(self, a): + return a.dtype.is_floating_point + class CupyBackend(Backend): # pragma: no cover """ @@ -1741,10 +1923,12 @@ class CupyBackend(Backend): # pragma: no cover cp.array(1, dtype=cp.float64) ] - def to_numpy(self, a): + def _to_numpy(self, a): return cp.asnumpy(a) - def from_numpy(self, a, type_as=None): + def _from_numpy(self, a, type_as=None): + if isinstance(a, float): + a = np.array(a) if type_as is None: return 
cp.asarray(a) else: @@ -1884,6 +2068,9 @@ class CupyBackend(Backend): # pragma: no cover def argmax(self, a, axis=None): return cp.argmax(a, axis=axis) + def argmin(self, a, axis=None): + return cp.argmin(a, axis=axis) + def mean(self, a, axis=None): return cp.mean(a, axis=axis) @@ -1982,8 +2169,11 @@ class CupyBackend(Backend): # pragma: no cover else: return a - def where(self, condition, x, y): - return cp.where(condition, x, y) + def where(self, condition, x=None, y=None): + if x is None and y is None: + return cp.where(condition) + else: + return cp.where(condition, x, y) def copy(self, a): return a.copy() @@ -2035,6 +2225,28 @@ class CupyBackend(Backend): # pragma: no cover pinned_mempool.free_all_blocks() return results + def solve(self, a, b): + return cp.linalg.solve(a, b) + + def trace(self, a): + return cp.trace(a) + + def inv(self, a): + return cp.linalg.inv(a) + + def sqrtm(self, a): + L, V = cp.linalg.eigh(a) + return (V * self.sqrt(L)[None, :]) @ V.T + + def isfinite(self, a): + return cp.isfinite(a) + + def array_equal(self, a, b): + return cp.array_equal(a, b) + + def is_floating_point(self, a): + return a.dtype.kind == "f" + class TensorflowBackend(Backend): @@ -2060,13 +2272,16 @@ class TensorflowBackend(Backend): "To use TensorflowBackend, you need to activate the tensorflow " "numpy API. You can activate it by running: \n" "from tensorflow.python.ops.numpy_ops import np_config\n" - "np_config.enable_numpy_behavior()" + "np_config.enable_numpy_behavior()", + stacklevel=2 ) - def to_numpy(self, a): + def _to_numpy(self, a): return a.numpy() - def from_numpy(self, a, type_as=None): + def _from_numpy(self, a, type_as=None): + if isinstance(a, float): + a = np.array(a) if not isinstance(a, self.__type__): if type_as is None: return tf.convert_to_tensor(a) @@ -2208,6 +2423,9 @@ class TensorflowBackend(Backend): def argmax(self, a, axis=None): return tnp.argmax(a, axis=axis) + def argmin(self, a, axis=None): + return tnp.argmin(a, axis=axis) + def mean(self, a, axis=None): return tnp.mean(a, axis=axis) @@ -2309,8 +2527,11 @@ class TensorflowBackend(Backend): else: return a - def where(self, condition, x, y): - return tnp.where(condition, x, y) + def where(self, condition, x=None, y=None): + if x is None and y is None: + return tnp.where(condition) + else: + return tnp.where(condition, x, y) def copy(self, a): return tf.identity(a) @@ -2364,3 +2585,24 @@ class TensorflowBackend(Backend): results[key] = (t1 - t0) / n_runs return results + + def solve(self, a, b): + return tf.linalg.solve(a, b) + + def trace(self, a): + return tf.linalg.trace(a) + + def inv(self, a): + return tf.linalg.inv(a) + + def sqrtm(self, a): + return tf.linalg.sqrtm(a) + + def isfinite(self, a): + return tnp.isfinite(a) + + def array_equal(self, a, b): + return tnp.array_equal(a, b) + + def is_floating_point(self, a): + return a.dtype.is_floating diff --git a/ot/bregman.py b/ot/bregman.py index fc20175..c06af2f 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -2525,8 +2525,7 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, # geometric interpolation delta = nx.exp(alpha * nx.log(other) + (1 - alpha) * nx.log(inv_new)) K = projR(K, delta) - K0 = nx.dot(nx.diag(nx.dot(D.T, delta / inv_new)), K0) - + K0 = nx.dot(D.T, delta / inv_new)[:, None] * K0 err = nx.norm(nx.sum(K0, axis=1) - old) old = new if log: @@ -2656,16 +2655,16 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, classes = nx.unique(Ys[d]) # build the corresponding D_1 and D_2 matrices - Dtmp1 = 
nx.zeros((nbclasses, nsk), type_as=Xs[0]) - Dtmp2 = nx.zeros((nbclasses, nsk), type_as=Xs[0]) + Dtmp1 = np.zeros((nbclasses, nsk)) + Dtmp2 = np.zeros((nbclasses, nsk)) for c in classes: - nbelemperclass = nx.sum(Ys[d] == c) + nbelemperclass = float(nx.sum(Ys[d] == c)) if nbelemperclass != 0: - Dtmp1[int(c), Ys[d] == c] = 1. - Dtmp2[int(c), Ys[d] == c] = 1. / (nbelemperclass) - D1.append(Dtmp1) - D2.append(Dtmp2) + Dtmp1[int(c), nx.to_numpy(Ys[d] == c)] = 1. + Dtmp2[int(c), nx.to_numpy(Ys[d] == c)] = 1. / (nbelemperclass) + D1.append(nx.from_numpy(Dtmp1, type_as=Xs[0])) + D2.append(nx.from_numpy(Dtmp2, type_as=Xs[0])) # build the cost matrix and the Gibbs kernel Mtmp = dist(Xs[d], Xt, metric=metric) diff --git a/ot/da.py b/ot/da.py index 841f31a..0b9737e 100644 --- a/ot/da.py +++ b/ot/da.py @@ -12,12 +12,12 @@ Domain adaptation with optimal transport # License: MIT License import numpy as np -import scipy.linalg as linalg +from .backend import get_backend from .bregman import sinkhorn, jcpot_barycenter from .lp import emd from .utils import unif, dist, kernel, cost_normalization, label_normalization, laplacian, dots -from .utils import check_params, BaseEstimator +from .utils import list_to_array, check_params, BaseEstimator from .unbalanced import sinkhorn_unbalanced from .optim import cg from .optim import gcg @@ -60,13 +60,13 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10, Parameters ---------- - a : np.ndarray (ns,) + a : array-like (ns,) samples weights in the source domain - labels_a : np.ndarray (ns,) + labels_a : array-like (ns,) labels of samples in the source domain - b : np.ndarray (nt,) + b : array-like (nt,) samples weights in the target domain - M : np.ndarray (ns,nt) + M : array-like (ns,nt) loss matrix reg : float Regularization term for entropic regularization >0 @@ -86,7 +86,7 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10, Returns ------- - gamma : (ns, nt) ndarray + gamma : (ns, nt) array-like Optimal transportation matrix for the given parameters log : dict log dictionary return only if log==True in parameters @@ -111,26 +111,28 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10, ot.optim.cg : General regularized OT """ + a, labels_a, b, M = list_to_array(a, labels_a, b, M) + nx = get_backend(a, labels_a, b, M) + p = 0.5 epsilon = 1e-3 indices_labels = [] - classes = np.unique(labels_a) + classes = nx.unique(labels_a) for c in classes: - idxc, = np.where(labels_a == c) + idxc, = nx.where(labels_a == c) indices_labels.append(idxc) - W = np.zeros(M.shape) - + W = nx.zeros(M.shape, type_as=M) for cpt in range(numItermax): Mreg = M + eta * W transp = sinkhorn(a, b, Mreg, reg, numItermax=numInnerItermax, stopThr=stopInnerThr) # the transport has been computed. 
Check if classes are really # separated - W = np.ones(M.shape) + W = nx.ones(M.shape, type_as=M) for (i, c) in enumerate(classes): - majs = np.sum(transp[indices_labels[i]], axis=0) + majs = nx.sum(transp[indices_labels[i]], axis=0) majs = p * ((majs + epsilon) ** (p - 1)) W[indices_labels[i]] = majs @@ -174,13 +176,13 @@ def sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, numItermax=10, Parameters ---------- - a : np.ndarray (ns,) + a : array-like (ns,) samples weights in the source domain - labels_a : np.ndarray (ns,) + labels_a : array-like (ns,) labels of samples in the source domain - b : np.ndarray (nt,) + b : array-like (nt,) samples in the target domain - M : np.ndarray (ns,nt) + M : array-like (ns,nt) loss matrix reg : float Regularization term for entropic regularization >0 @@ -200,7 +202,7 @@ def sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, numItermax=10, Returns ------- - gamma : (ns, nt) ndarray + gamma : (ns, nt) array-like Optimal transportation matrix for the given parameters log : dict log dictionary return only if log==True in parameters @@ -222,22 +224,25 @@ def sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, numItermax=10, ot.optim.gcg : Generalized conditional gradient for OT problems """ - lstlab = np.unique(labels_a) + a, labels_a, b, M = list_to_array(a, labels_a, b, M) + nx = get_backend(a, labels_a, b, M) + + lstlab = nx.unique(labels_a) def f(G): res = 0 for i in range(G.shape[1]): for lab in lstlab: temp = G[labels_a == lab, i] - res += np.linalg.norm(temp) + res += nx.norm(temp) return res def df(G): - W = np.zeros(G.shape) + W = nx.zeros(G.shape, type_as=G) for i in range(G.shape[1]): for lab in lstlab: temp = G[labels_a == lab, i] - n = np.linalg.norm(temp) + n = nx.norm(temp) if n: W[labels_a == lab, i] = temp / n return W @@ -289,9 +294,9 @@ def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False, Parameters ---------- - xs : np.ndarray (ns,d) + xs : array-like (ns,d) samples in the source domain - xt : np.ndarray (nt,d) + xt : array-like (nt,d) samples in the target domain mu : float,optional Weight for the linear OT loss (>0) @@ -315,9 +320,9 @@ def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False, Returns ------- - gamma : (ns, nt) ndarray + gamma : (ns, nt) array-like Optimal transportation matrix for the given parameters - L : (d, d) ndarray + L : (d, d) array-like Linear mapping matrix ((:math:`d+1`, `d`) if bias) log : dict log dictionary return only if log==True in parameters @@ -336,13 +341,15 @@ def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False, ot.optim.cg : General regularized OT """ + xs, xt = list_to_array(xs, xt) + nx = get_backend(xs, xt) ns, nt, d = xs.shape[0], xt.shape[0], xt.shape[1] if bias: - xs1 = np.hstack((xs, np.ones((ns, 1)))) - xstxs = xs1.T.dot(xs1) - Id = np.eye(d + 1) + xs1 = nx.concatenate((xs, nx.ones((ns, 1), type_as=xs)), axis=1) + xstxs = nx.dot(xs1.T, xs1) + Id = nx.eye(d + 1, type_as=xs) Id[-1] = 0 I0 = Id[:, :-1] @@ -350,8 +357,8 @@ def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False, return x[:-1, :] else: xs1 = xs - xstxs = xs1.T.dot(xs1) - Id = np.eye(d) + xstxs = nx.dot(xs1.T, xs1) + Id = nx.eye(d, type_as=xs) I0 = Id def sel(x): @@ -360,7 +367,8 @@ def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False, if log: log = {'err': []} - a, b = unif(ns), unif(nt) + a = unif(ns, type_as=xs) + b = unif(nt, type_as=xt) M = dist(xs, xt) * ns G = emd(a, b, M) @@ -368,23 +376,26 @@ def 
joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False, def loss(L, G): """Compute full loss""" - return np.sum((xs1.dot(L) - ns * G.dot(xt)) ** 2) + mu * \ - np.sum(G * M) + eta * np.sum(sel(L - I0) ** 2) + return ( + nx.sum((nx.dot(xs1, L) - ns * nx.dot(G, xt)) ** 2) + + mu * nx.sum(G * M) + + eta * nx.sum(sel(L - I0) ** 2) + ) def solve_L(G): """ solve L problem with fixed G (least square)""" - xst = ns * G.dot(xt) - return np.linalg.solve(xstxs + eta * Id, xs1.T.dot(xst) + eta * I0) + xst = ns * nx.dot(G, xt) + return nx.solve(xstxs + eta * Id, nx.dot(xs1.T, xst) + eta * I0) def solve_G(L, G0): """Update G with CG algorithm""" - xsi = xs1.dot(L) + xsi = nx.dot(xs1, L) def f(G): - return np.sum((xsi - ns * G.dot(xt)) ** 2) + return nx.sum((xsi - ns * nx.dot(G, xt)) ** 2) def df(G): - return -2 * ns * (xsi - ns * G.dot(xt)).dot(xt.T) + return -2 * ns * nx.dot(xsi - ns * nx.dot(G, xt), xt.T) G = cg(a, b, M, 1.0 / mu, f, df, G0=G0, numItermax=numInnerItermax, stopThr=stopInnerThr) @@ -481,9 +492,9 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian', Parameters ---------- - xs : np.ndarray (ns,d) + xs : array-like (ns,d) samples in the source domain - xt : np.ndarray (nt,d) + xt : array-like (nt,d) samples in the target domain mu : float,optional Weight for the linear OT loss (>0) @@ -513,9 +524,9 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian', Returns ------- - gamma : (ns, nt) ndarray + gamma : (ns, nt) array-like Optimal transportation matrix for the given parameters - L : (ns, d) ndarray + L : (ns, d) array-like Nonlinear mapping matrix ((:math:`n_s+1`, `d`) if bias) log : dict log dictionary return only if log==True in parameters @@ -534,15 +545,17 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian', ot.optim.cg : General regularized OT """ + xs, xt = list_to_array(xs, xt) + nx = get_backend(xs, xt) ns, nt = xs.shape[0], xt.shape[0] K = kernel(xs, xs, method=kerneltype, sigma=sigma) if bias: - K1 = np.hstack((K, np.ones((ns, 1)))) - Id = np.eye(ns + 1) + K1 = nx.concatenate((K, nx.ones((ns, 1), type_as=xs)), axis=1) + Id = nx.eye(ns + 1, type_as=xs) Id[-1] = 0 - Kp = np.eye(ns + 1) + Kp = nx.eye(ns + 1, type_as=xs) Kp[:ns, :ns] = K # ls regu @@ -550,12 +563,12 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian', # Kreg=I # RKHS regul - K0 = K1.T.dot(K1) + eta * Kp + K0 = nx.dot(K1.T, K1) + eta * Kp Kreg = Kp else: K1 = K - Id = np.eye(ns) + Id = nx.eye(ns, type_as=xs) # ls regul # K0 = K1.T.dot(K1)+eta*I @@ -568,7 +581,8 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian', if log: log = {'err': []} - a, b = unif(ns), unif(nt) + a = unif(ns, type_as=xs) + b = unif(nt, type_as=xt) M = dist(xs, xt) * ns G = emd(a, b, M) @@ -576,28 +590,31 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian', def loss(L, G): """Compute full loss""" - return np.sum((K1.dot(L) - ns * G.dot(xt)) ** 2) + mu * \ - np.sum(G * M) + eta * np.trace(L.T.dot(Kreg).dot(L)) + return ( + nx.sum((nx.dot(K1, L) - ns * nx.dot(G, xt)) ** 2) + + mu * nx.sum(G * M) + + eta * nx.trace(dots(L.T, Kreg, L)) + ) def solve_L_nobias(G): """ solve L problem with fixed G (least square)""" - xst = ns * G.dot(xt) - return np.linalg.solve(K0, xst) + xst = ns * nx.dot(G, xt) + return nx.solve(K0, xst) def solve_L_bias(G): """ solve L problem with fixed G (least square)""" - xst = ns * G.dot(xt) - return np.linalg.solve(K0, K1.T.dot(xst)) + xst = ns * nx.dot(G, 
xt) + return nx.solve(K0, nx.dot(K1.T, xst)) def solve_G(L, G0): """Update G with CG algorithm""" - xsi = K1.dot(L) + xsi = nx.dot(K1, L) def f(G): - return np.sum((xsi - ns * G.dot(xt)) ** 2) + return nx.sum((xsi - ns * nx.dot(G, xt)) ** 2) def df(G): - return -2 * ns * (xsi - ns * G.dot(xt)).dot(xt.T) + return -2 * ns * nx.dot(xsi - ns * nx.dot(G, xt), xt.T) G = cg(a, b, M, 1.0 / mu, f, df, G0=G0, numItermax=numInnerItermax, stopThr=stopInnerThr) @@ -681,15 +698,15 @@ def OT_mapping_linear(xs, xt, reg=1e-6, ws=None, Parameters ---------- - xs : np.ndarray (ns,d) + xs : array-like (ns,d) samples in the source domain - xt : np.ndarray (nt,d) + xt : array-like (nt,d) samples in the target domain reg : float,optional regularization added to the diagonals of covariances (>0) - ws : np.ndarray (ns,1), optional + ws : array-like (ns,1), optional weights for the source samples - wt : np.ndarray (ns,1), optional + wt : array-like (ns,1), optional weights for the target samples bias: boolean, optional estimate bias :math:`\mathbf{b}` else :math:`\mathbf{b} = 0` (default:True) @@ -699,9 +716,9 @@ def OT_mapping_linear(xs, xt, reg=1e-6, ws=None, Returns ------- - A : (d, d) ndarray + A : (d, d) array-like Linear operator - b : (1, d) ndarray + b : (1, d) array-like bias log : dict log dictionary return only if log==True in parameters @@ -719,36 +736,38 @@ def OT_mapping_linear(xs, xt, reg=1e-6, ws=None, """ + xs, xt = list_to_array(xs, xt) + nx = get_backend(xs, xt) d = xs.shape[1] if bias: - mxs = xs.mean(0, keepdims=True) - mxt = xt.mean(0, keepdims=True) + mxs = nx.mean(xs, axis=0)[None, :] + mxt = nx.mean(xt, axis=0)[None, :] xs = xs - mxs xt = xt - mxt else: - mxs = np.zeros((1, d)) - mxt = np.zeros((1, d)) + mxs = nx.zeros((1, d), type_as=xs) + mxt = nx.zeros((1, d), type_as=xs) if ws is None: - ws = np.ones((xs.shape[0], 1)) / xs.shape[0] + ws = nx.ones((xs.shape[0], 1), type_as=xs) / xs.shape[0] if wt is None: - wt = np.ones((xt.shape[0], 1)) / xt.shape[0] + wt = nx.ones((xt.shape[0], 1), type_as=xt) / xt.shape[0] - Cs = (xs * ws).T.dot(xs) / ws.sum() + reg * np.eye(d) - Ct = (xt * wt).T.dot(xt) / wt.sum() + reg * np.eye(d) + Cs = nx.dot((xs * ws).T, xs) / nx.sum(ws) + reg * nx.eye(d, type_as=xs) + Ct = nx.dot((xt * wt).T, xt) / nx.sum(wt) + reg * nx.eye(d, type_as=xt) - Cs12 = linalg.sqrtm(Cs) - Cs_12 = linalg.inv(Cs12) + Cs12 = nx.sqrtm(Cs) + Cs_12 = nx.inv(Cs12) - M0 = linalg.sqrtm(Cs12.dot(Ct.dot(Cs12))) + M0 = nx.sqrtm(dots(Cs12, Ct, Cs12)) - A = Cs_12.dot(M0.dot(Cs_12)) + A = dots(Cs_12, M0, Cs_12) - b = mxt - mxs.dot(A) + b = mxt - nx.dot(mxs, A) if log: log = {} @@ -798,15 +817,15 @@ def emd_laplace(a, b, xs, xt, M, sim='knn', sim_param=None, reg='pos', eta=1, al Parameters ---------- - a : np.ndarray (ns,) + a : array-like (ns,) samples weights in the source domain - b : np.ndarray (nt,) + b : array-like (nt,) samples weights in the target domain - xs : np.ndarray (ns,d) + xs : array-like (ns,d) samples in the source domain - xt : np.ndarray (nt,d) + xt : array-like (nt,d) samples in the target domain - M : np.ndarray (ns,nt) + M : array-like (ns,nt) loss matrix sim : string, optional Type of similarity ('knn' or 'gauss') used to construct the Laplacian. 
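For context, the hunks above and below make ot.da.emd_laplace accept arrays from any backend while keeping the 'knn'/'gauss' similarity options used to build the graph Laplacians. The following is a minimal usage sketch, not part of the patch: the toy data, sample sizes and sim_param value are assumptions, and sim='knn' additionally requires scikit-learn for kneighbors_graph, as in the code further down.

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    xs = rng.randn(20, 2)           # source samples
    xt = rng.randn(20, 2) + 1.0     # target samples
    a = ot.unif(20)                 # uniform source weights
    b = ot.unif(20)                 # uniform target weights
    M = ot.dist(xs, xt)             # squared Euclidean cost matrix

    # Laplacian-regularized EMD with a symmetrized 3-nearest-neighbour
    # similarity graph on each domain (sim_param must be an int for 'knn')
    G = ot.da.emd_laplace(a, b, xs, xt, M, sim='knn', sim_param=3)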
@@ -834,7 +853,7 @@ def emd_laplace(a, b, xs, xt, M, sim='knn', sim_param=None, reg='pos', eta=1, al Returns ------- - gamma : (ns, nt) ndarray + gamma : (ns, nt) array-like Optimal transportation matrix for the given parameters log : dict log dictionary return only if log==True in parameters @@ -862,9 +881,12 @@ def emd_laplace(a, b, xs, xt, M, sim='knn', sim_param=None, reg='pos', eta=1, al raise ValueError( 'Similarity parameter should be an int or a float. Got {type} instead.'.format(type=type(sim_param).__name__)) + a, b, xs, xt, M = list_to_array(a, b, xs, xt, M) + nx = get_backend(a, b, xs, xt, M) + if sim == 'gauss': if sim_param is None: - sim_param = 1 / (2 * (np.mean(dist(xs, xs, 'sqeuclidean')) ** 2)) + sim_param = 1 / (2 * (nx.mean(dist(xs, xs, 'sqeuclidean')) ** 2)) sS = kernel(xs, xs, method=sim, sigma=sim_param) sT = kernel(xt, xt, method=sim, sigma=sim_param) @@ -874,9 +896,13 @@ def emd_laplace(a, b, xs, xt, M, sim='knn', sim_param=None, reg='pos', eta=1, al from sklearn.neighbors import kneighbors_graph - sS = kneighbors_graph(X=xs, n_neighbors=int(sim_param)).toarray() + sS = nx.from_numpy(kneighbors_graph( + X=nx.to_numpy(xs), n_neighbors=int(sim_param) + ).toarray(), type_as=xs) sS = (sS + sS.T) / 2 - sT = kneighbors_graph(xt, n_neighbors=int(sim_param)).toarray() + sT = nx.from_numpy(kneighbors_graph( + X=nx.to_numpy(xt), n_neighbors=int(sim_param) + ).toarray(), type_as=xt) sT = (sT + sT.T) / 2 else: raise ValueError('Unknown similarity type {sim}. Currently supported similarity types are "knn" and "gauss".'.format(sim=sim)) @@ -885,12 +911,14 @@ def emd_laplace(a, b, xs, xt, M, sim='knn', sim_param=None, reg='pos', eta=1, al lT = laplacian(sT) def f(G): - return alpha * np.trace(np.dot(xt.T, np.dot(G.T, np.dot(lS, np.dot(G, xt))))) \ - + (1 - alpha) * np.trace(np.dot(xs.T, np.dot(G, np.dot(lT, np.dot(G.T, xs))))) + return ( + alpha * nx.trace(dots(xt.T, G.T, lS, G, xt)) + + (1 - alpha) * nx.trace(dots(xs.T, G, lT, G.T, xs)) + ) ls2 = lS + lS.T lt2 = lT + lT.T - xt2 = np.dot(xt, xt.T) + xt2 = nx.dot(xt, xt.T) if reg == 'disp': Cs = -eta * alpha / xs.shape[0] * dots(ls2, xs, xt.T) @@ -898,8 +926,10 @@ def emd_laplace(a, b, xs, xt, M, sim='knn', sim_param=None, reg='pos', eta=1, al M = M + Cs + Ct def df(G): - return alpha * np.dot(ls2, np.dot(G, xt2))\ - + (1 - alpha) * np.dot(xs, np.dot(xs.T, np.dot(G, lt2))) + return ( + alpha * dots(ls2, G, xt2) + + (1 - alpha) * dots(xs, xs.T, G, lt2) + ) return cg(a, b, M, reg=eta, f=f, df=df, G0=None, numItermax=numItermax, numItermaxEmd=numInnerItermax, stopThr=stopThr, stopThr2=stopInnerThr, verbose=verbose, log=log) @@ -919,7 +949,7 @@ def distribution_estimation_uniform(X): The uniform distribution estimated from :math:`\mathbf{X}` """ - return unif(X.shape[0]) + return unif(X.shape[0], type_as=X) class BaseTransport(BaseEstimator): @@ -973,6 +1003,7 @@ class BaseTransport(BaseEstimator): self : object Returns self. 
""" + nx = self._get_backend(Xs, ys, Xt, yt) # check the necessary inputs parameters are here if check_params(Xs=Xs, Xt=Xt): @@ -984,14 +1015,14 @@ class BaseTransport(BaseEstimator): if (ys is not None) and (yt is not None): if self.limit_max != np.infty: - self.limit_max = self.limit_max * np.max(self.cost_) + self.limit_max = self.limit_max * nx.max(self.cost_) # assumes labeled source samples occupy the first rows # and labeled target samples occupy the first columns - classes = [c for c in np.unique(ys) if c != -1] + classes = [c for c in nx.unique(ys) if c != -1] for c in classes: - idx_s = np.where((ys != c) & (ys != -1)) - idx_t = np.where(yt == c) + idx_s = nx.where((ys != c) & (ys != -1)) + idx_t = nx.where(yt == c) # all the coefficients corresponding to a source sample # and a target sample : @@ -1062,23 +1093,24 @@ class BaseTransport(BaseEstimator): transp_Xs : array-like, shape (n_source_samples, n_features) The transport source samples. """ + nx = self.nx # check the necessary inputs parameters are here if check_params(Xs=Xs): - if np.array_equal(self.xs_, Xs): + if nx.array_equal(self.xs_, Xs): # perform standard barycentric mapping - transp = self.coupling_ / np.sum(self.coupling_, 1)[:, None] + transp = self.coupling_ / nx.sum(self.coupling_, axis=1)[:, None] # set nans to 0 - transp[~ np.isfinite(transp)] = 0 + transp[~ nx.isfinite(transp)] = 0 # compute transported samples - transp_Xs = np.dot(transp, self.xt_) + transp_Xs = nx.dot(transp, self.xt_) else: # perform out of sample mapping - indices = np.arange(Xs.shape[0]) + indices = nx.arange(Xs.shape[0]) batch_ind = [ indices[i:i + batch_size] for i in range(0, len(indices), batch_size)] @@ -1087,20 +1119,20 @@ class BaseTransport(BaseEstimator): for bi in batch_ind: # get the nearest neighbor in the source domain D0 = dist(Xs[bi], self.xs_) - idx = np.argmin(D0, axis=1) + idx = nx.argmin(D0, axis=1) # transport the source samples - transp = self.coupling_ / np.sum( - self.coupling_, 1)[:, None] - transp[~ np.isfinite(transp)] = 0 - transp_Xs_ = np.dot(transp, self.xt_) + transp = self.coupling_ / nx.sum( + self.coupling_, axis=1)[:, None] + transp[~ nx.isfinite(transp)] = 0 + transp_Xs_ = nx.dot(transp, self.xt_) # define the transported points transp_Xs_ = transp_Xs_[idx, :] + Xs[bi] - self.xs_[idx, :] transp_Xs.append(transp_Xs_) - transp_Xs = np.concatenate(transp_Xs, axis=0) + transp_Xs = nx.concatenate(transp_Xs, axis=0) return transp_Xs @@ -1127,26 +1159,27 @@ class BaseTransport(BaseEstimator): International Conference on Artificial Intelligence and Statistics (AISTATS), 2019. """ + nx = self.nx # check the necessary inputs parameters are here if check_params(ys=ys): - ysTemp = label_normalization(np.copy(ys)) - classes = np.unique(ysTemp) + ysTemp = label_normalization(nx.copy(ys)) + classes = nx.unique(ysTemp) n = len(classes) - D1 = np.zeros((n, len(ysTemp))) + D1 = nx.zeros((n, len(ysTemp)), type_as=self.coupling_) # perform label propagation - transp = self.coupling_ / np.sum(self.coupling_, 0, keepdims=True) + transp = self.coupling_ / nx.sum(self.coupling_, axis=0)[None, :] # set nans to 0 - transp[~ np.isfinite(transp)] = 0 + transp[~ nx.isfinite(transp)] = 0 for c in classes: D1[int(c), ysTemp == c] = 1 # compute propagated labels - transp_ys = np.dot(D1, transp) + transp_ys = nx.dot(D1, transp) return transp_ys.T @@ -1176,23 +1209,24 @@ class BaseTransport(BaseEstimator): transp_Xt : array-like, shape (n_source_samples, n_features) The transported target samples. 
""" + nx = self.nx # check the necessary inputs parameters are here if check_params(Xt=Xt): - if np.array_equal(self.xt_, Xt): + if nx.array_equal(self.xt_, Xt): # perform standard barycentric mapping - transp_ = self.coupling_.T / np.sum(self.coupling_, 0)[:, None] + transp_ = self.coupling_.T / nx.sum(self.coupling_, 0)[:, None] # set nans to 0 - transp_[~ np.isfinite(transp_)] = 0 + transp_[~ nx.isfinite(transp_)] = 0 # compute transported samples - transp_Xt = np.dot(transp_, self.xs_) + transp_Xt = nx.dot(transp_, self.xs_) else: # perform out of sample mapping - indices = np.arange(Xt.shape[0]) + indices = nx.arange(Xt.shape[0]) batch_ind = [ indices[i:i + batch_size] for i in range(0, len(indices), batch_size)] @@ -1200,20 +1234,20 @@ class BaseTransport(BaseEstimator): transp_Xt = [] for bi in batch_ind: D0 = dist(Xt[bi], self.xt_) - idx = np.argmin(D0, axis=1) + idx = nx.argmin(D0, axis=1) # transport the target samples - transp_ = self.coupling_.T / np.sum( + transp_ = self.coupling_.T / nx.sum( self.coupling_, 0)[:, None] - transp_[~ np.isfinite(transp_)] = 0 - transp_Xt_ = np.dot(transp_, self.xs_) + transp_[~ nx.isfinite(transp_)] = 0 + transp_Xt_ = nx.dot(transp_, self.xs_) # define the transported points transp_Xt_ = transp_Xt_[idx, :] + Xt[bi] - self.xt_[idx, :] transp_Xt.append(transp_Xt_) - transp_Xt = np.concatenate(transp_Xt, axis=0) + transp_Xt = nx.concatenate(transp_Xt, axis=0) return transp_Xt @@ -1230,26 +1264,27 @@ class BaseTransport(BaseEstimator): transp_ys : array-like, shape (n_source_samples, nb_classes) Estimated soft source labels. """ + nx = self.nx # check the necessary inputs parameters are here if check_params(yt=yt): - ytTemp = label_normalization(np.copy(yt)) - classes = np.unique(ytTemp) + ytTemp = label_normalization(nx.copy(yt)) + classes = nx.unique(ytTemp) n = len(classes) - D1 = np.zeros((n, len(ytTemp))) + D1 = nx.zeros((n, len(ytTemp)), type_as=self.coupling_) # perform label propagation - transp = self.coupling_ / np.sum(self.coupling_, 1)[:, None] + transp = self.coupling_ / nx.sum(self.coupling_, 1)[:, None] # set nans to 0 - transp[~ np.isfinite(transp)] = 0 + transp[~ nx.isfinite(transp)] = 0 for c in classes: D1[int(c), ytTemp == c] = 1 # compute propagated samples - transp_ys = np.dot(D1, transp.T) + transp_ys = nx.dot(D1, transp.T) return transp_ys.T @@ -1330,14 +1365,15 @@ class LinearTransport(BaseTransport): self : object Returns self. """ + nx = self._get_backend(Xs, ys, Xt, yt) self.mu_s = self.distribution_estimation(Xs) self.mu_t = self.distribution_estimation(Xt) # coupling estimation returned_ = OT_mapping_linear(Xs, Xt, reg=self.reg, - ws=self.mu_s.reshape((-1, 1)), - wt=self.mu_t.reshape((-1, 1)), + ws=nx.reshape(self.mu_s, (-1, 1)), + wt=nx.reshape(self.mu_t, (-1, 1)), bias=self.bias, log=self.log) # deal with the value of log @@ -1348,8 +1384,8 @@ class LinearTransport(BaseTransport): self.log_ = dict() # re compute inverse mapping - self.A1_ = linalg.inv(self.A_) - self.B1_ = -self.B_.dot(self.A1_) + self.A1_ = nx.inv(self.A_) + self.B1_ = -nx.dot(self.B_, self.A1_) return self @@ -1378,10 +1414,11 @@ class LinearTransport(BaseTransport): transp_Xs : array-like, shape (n_source_samples, n_features) The transport source samples. 
""" + nx = self.nx # check the necessary inputs parameters are here if check_params(Xs=Xs): - transp_Xs = Xs.dot(self.A_) + self.B_ + transp_Xs = nx.dot(Xs, self.A_) + self.B_ return transp_Xs @@ -1411,10 +1448,11 @@ class LinearTransport(BaseTransport): transp_Xt : array-like, shape (n_source_samples, n_features) The transported target samples. """ + nx = self.nx # check the necessary inputs parameters are here if check_params(Xt=Xt): - transp_Xt = Xt.dot(self.A1_) + self.B1_ + transp_Xt = nx.dot(Xt, self.A1_) + self.B1_ return transp_Xt @@ -2112,6 +2150,7 @@ class MappingTransport(BaseEstimator): self : object Returns self """ + self._get_backend(Xs, ys, Xt, yt) # check the necessary inputs parameters are here if check_params(Xs=Xs, Xt=Xt): @@ -2158,19 +2197,20 @@ class MappingTransport(BaseEstimator): transp_Xs : array-like, shape (n_source_samples, n_features) The transport source samples. """ + nx = self.nx # check the necessary inputs parameters are here if check_params(Xs=Xs): - if np.array_equal(self.xs_, Xs): + if nx.array_equal(self.xs_, Xs): # perform standard barycentric mapping - transp = self.coupling_ / np.sum(self.coupling_, 1)[:, None] + transp = self.coupling_ / nx.sum(self.coupling_, 1)[:, None] # set nans to 0 - transp[~ np.isfinite(transp)] = 0 + transp[~ nx.isfinite(transp)] = 0 # compute transported samples - transp_Xs = np.dot(transp, self.xt_) + transp_Xs = nx.dot(transp, self.xt_) else: if self.kernel == "gaussian": K = kernel(Xs, self.xs_, method=self.kernel, @@ -2178,8 +2218,10 @@ class MappingTransport(BaseEstimator): elif self.kernel == "linear": K = Xs if self.bias: - K = np.hstack((K, np.ones((Xs.shape[0], 1)))) - transp_Xs = K.dot(self.mapping_) + K = nx.concatenate( + [K, nx.ones((Xs.shape[0], 1), type_as=K)], axis=1 + ) + transp_Xs = nx.dot(K, self.mapping_) return transp_Xs @@ -2396,6 +2438,7 @@ class JCPOTTransport(BaseTransport): self : object Returns self. 
""" + self._get_backend(*Xs, *ys, Xt, yt) # check the necessary inputs parameters are here if check_params(Xs=Xs, Xt=Xt, ys=ys): @@ -2438,28 +2481,29 @@ class JCPOTTransport(BaseTransport): batch_size : int, optional (default=128) The batch size for out of sample inverse transform """ + nx = self.nx transp_Xs = [] # check the necessary inputs parameters are here if check_params(Xs=Xs): - if all([np.allclose(x, y) for x, y in zip(self.xs_, Xs)]): + if all([nx.allclose(x, y) for x, y in zip(self.xs_, Xs)]): # perform standard barycentric mapping for each source domain for coupling in self.coupling_: - transp = coupling / np.sum(coupling, 1)[:, None] + transp = coupling / nx.sum(coupling, 1)[:, None] # set nans to 0 - transp[~ np.isfinite(transp)] = 0 + transp[~ nx.isfinite(transp)] = 0 # compute transported samples - transp_Xs.append(np.dot(transp, self.xt_)) + transp_Xs.append(nx.dot(transp, self.xt_)) else: # perform out of sample mapping - indices = np.arange(Xs.shape[0]) + indices = nx.arange(Xs.shape[0]) batch_ind = [ indices[i:i + batch_size] for i in range(0, len(indices), batch_size)] @@ -2470,23 +2514,22 @@ class JCPOTTransport(BaseTransport): transp_Xs_ = [] # get the nearest neighbor in the sources domains - xs = np.concatenate(self.xs_, axis=0) - idx = np.argmin(dist(Xs[bi], xs), axis=1) + xs = nx.concatenate(self.xs_, axis=0) + idx = nx.argmin(dist(Xs[bi], xs), axis=1) # transport the source samples for coupling in self.coupling_: - transp = coupling / np.sum( - coupling, 1)[:, None] - transp[~ np.isfinite(transp)] = 0 - transp_Xs_.append(np.dot(transp, self.xt_)) + transp = coupling / nx.sum(coupling, 1)[:, None] + transp[~ nx.isfinite(transp)] = 0 + transp_Xs_.append(nx.dot(transp, self.xt_)) - transp_Xs_ = np.concatenate(transp_Xs_, axis=0) + transp_Xs_ = nx.concatenate(transp_Xs_, axis=0) # define the transported points transp_Xs_ = transp_Xs_[idx, :] + Xs[bi] - xs[idx, :] transp_Xs.append(transp_Xs_) - transp_Xs = np.concatenate(transp_Xs, axis=0) + transp_Xs = nx.concatenate(transp_Xs, axis=0) return transp_Xs @@ -2512,32 +2555,36 @@ class JCPOTTransport(BaseTransport): "Optimal transport for multi-source domain adaptation under target shift", International Conference on Artificial Intelligence and Statistics (AISTATS), 2019. 
""" + nx = self.nx # check the necessary inputs parameters are here if check_params(ys=ys): - yt = np.zeros((len(np.unique(np.concatenate(ys))), self.xt_.shape[0])) + yt = nx.zeros( + (len(nx.unique(nx.concatenate(ys))), self.xt_.shape[0]), + type_as=ys[0] + ) for i in range(len(ys)): - ysTemp = label_normalization(np.copy(ys[i])) - classes = np.unique(ysTemp) + ysTemp = label_normalization(nx.copy(ys[i])) + classes = nx.unique(ysTemp) n = len(classes) ns = len(ysTemp) # perform label propagation - transp = self.coupling_[i] / np.sum(self.coupling_[i], 1)[:, None] + transp = self.coupling_[i] / nx.sum(self.coupling_[i], 1)[:, None] # set nans to 0 - transp[~ np.isfinite(transp)] = 0 + transp[~ nx.isfinite(transp)] = 0 if self.log: D1 = self.log_['D1'][i] else: - D1 = np.zeros((n, ns)) + D1 = nx.zeros((n, ns), type_as=transp) for c in classes: D1[int(c), ysTemp == c] = 1 # compute propagated labels - yt = yt + np.dot(D1, transp) / len(ys) + yt = yt + nx.dot(D1, transp) / len(ys) return yt.T @@ -2555,14 +2602,15 @@ class JCPOTTransport(BaseTransport): transp_ys : list of K array-like objects, shape K x (nk_source_samples, nb_classes) A list of estimated soft source labels """ + nx = self.nx # check the necessary inputs parameters are here if check_params(yt=yt): transp_ys = [] - ytTemp = label_normalization(np.copy(yt)) - classes = np.unique(ytTemp) + ytTemp = label_normalization(nx.copy(yt)) + classes = nx.unique(ytTemp) n = len(classes) - D1 = np.zeros((n, len(ytTemp))) + D1 = nx.zeros((n, len(ytTemp)), type_as=self.coupling_[0]) for c in classes: D1[int(c), ytTemp == c] = 1 @@ -2570,12 +2618,12 @@ class JCPOTTransport(BaseTransport): for i in range(len(self.xs_)): # perform label propagation - transp = self.coupling_[i] / np.sum(self.coupling_[i], 1)[:, None] + transp = self.coupling_[i] / nx.sum(self.coupling_[i], 1)[:, None] # set nans to 0 - transp[~ np.isfinite(transp)] = 0 + transp[~ nx.isfinite(transp)] = 0 # compute propagated labels - transp_ys.append(np.dot(D1, transp.T).T) + transp_ys.append(nx.dot(D1, transp.T).T) return transp_ys diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index d9b6fa9..abf7fe0 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -225,6 +225,13 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1): from all compatible backends. But the algorithm uses the C++ CPU backend which can lead to copy overhead on GPU arrays. + .. note:: This function will cast the computed transport plan to the data type + of the provided input with the following priority: :math:`\mathbf{a}`, + then :math:`\mathbf{b}`, then :math:`\mathbf{M}` if marginals are not provided. + Casting to an integer tensor might result in a loss of precision. + If this behaviour is unwanted, please make sure to provide a + floating point input. + Uses the algorithm proposed in :ref:`[1] `. 
Parameters @@ -290,12 +297,16 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1): a, b, M = list_to_array(a, b, M) a0, b0, M0 = a, b, M + if len(a0) != 0: + type_as = a0 + elif len(b0) != 0: + type_as = b0 + else: + type_as = M0 nx = get_backend(M0, a0, b0) # convert to numpy - M = nx.to_numpy(M) - a = nx.to_numpy(a) - b = nx.to_numpy(b) + M, a, b = nx.to_numpy(M, a, b) # ensure float64 a = np.asarray(a, dtype=np.float64) @@ -330,15 +341,23 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1): u, v = estimate_dual_null_weights(u, v, a, b, M) result_code_string = check_result(result_code) + if not nx.is_floating_point(type_as): + warnings.warn( + "Input histogram consists of integer. The transport plan will be " + "casted accordingly, possibly resulting in a loss of precision. " + "If this behaviour is unwanted, please make sure your input " + "histogram consists of floating point elements.", + stacklevel=2 + ) if log: log = {} log['cost'] = cost - log['u'] = nx.from_numpy(u, type_as=a0) - log['v'] = nx.from_numpy(v, type_as=b0) + log['u'] = nx.from_numpy(u, type_as=type_as) + log['v'] = nx.from_numpy(v, type_as=type_as) log['warning'] = result_code_string log['result_code'] = result_code - return nx.from_numpy(G, type_as=M0), log - return nx.from_numpy(G, type_as=M0) + return nx.from_numpy(G, type_as=type_as), log + return nx.from_numpy(G, type_as=type_as) def emd2(a, b, M, processes=1, @@ -364,6 +383,14 @@ def emd2(a, b, M, processes=1, from all compatible backends. But the algorithm uses the C++ CPU backend which can lead to copy overhead on GPU arrays. + .. note:: This function will cast the computed transport plan and + transportation loss to the data type of the provided input with the + following priority: :math:`\mathbf{a}`, then :math:`\mathbf{b}`, + then :math:`\mathbf{M}` if marginals are not provided. + Casting to an integer tensor might result in a loss of precision. + If this behaviour is unwanted, please make sure to provide a + floating point input. + Uses the algorithm proposed in :ref:`[1] `. Parameters @@ -432,12 +459,16 @@ def emd2(a, b, M, processes=1, a, b, M = list_to_array(a, b, M) a0, b0, M0 = a, b, M + if len(a0) != 0: + type_as = a0 + elif len(b0) != 0: + type_as = b0 + else: + type_as = M0 nx = get_backend(M0, a0, b0) # convert to numpy - M = nx.to_numpy(M) - a = nx.to_numpy(a) - b = nx.to_numpy(b) + M, a, b = nx.to_numpy(M, a, b) a = np.asarray(a, dtype=np.float64) b = np.asarray(b, dtype=np.float64) @@ -470,14 +501,22 @@ def emd2(a, b, M, processes=1, result_code_string = check_result(result_code) log = {} - G = nx.from_numpy(G, type_as=M0) + if not nx.is_floating_point(type_as): + warnings.warn( + "Input histogram consists of integer. The transport plan will be " + "casted accordingly, possibly resulting in a loss of precision. 
" + "If this behaviour is unwanted, please make sure your input " + "histogram consists of floating point elements.", + stacklevel=2 + ) + G = nx.from_numpy(G, type_as=type_as) if return_matrix: log['G'] = G - log['u'] = nx.from_numpy(u, type_as=a0) - log['v'] = nx.from_numpy(v, type_as=b0) + log['u'] = nx.from_numpy(u, type_as=type_as) + log['v'] = nx.from_numpy(v, type_as=type_as) log['warning'] = result_code_string log['result_code'] = result_code - cost = nx.set_gradients(nx.from_numpy(cost, type_as=M0), + cost = nx.set_gradients(nx.from_numpy(cost, type_as=type_as), (a0, b0, M0), (log['u'], log['v'], G)) return [cost, log] else: @@ -491,10 +530,18 @@ def emd2(a, b, M, processes=1, if np.any(~asel) or np.any(~bsel): u, v = estimate_dual_null_weights(u, v, a, b, M) - G = nx.from_numpy(G, type_as=M0) - cost = nx.set_gradients(nx.from_numpy(cost, type_as=M0), - (a0, b0, M0), (nx.from_numpy(u, type_as=a0), - nx.from_numpy(v, type_as=b0), G)) + if not nx.is_floating_point(type_as): + warnings.warn( + "Input histogram consists of integer. The transport plan will be " + "casted accordingly, possibly resulting in a loss of precision. " + "If this behaviour is unwanted, please make sure your input " + "histogram consists of floating point elements.", + stacklevel=2 + ) + G = nx.from_numpy(G, type_as=type_as) + cost = nx.set_gradients(nx.from_numpy(cost, type_as=type_as), + (a0, b0, M0), (nx.from_numpy(u, type_as=type_as), + nx.from_numpy(v, type_as=type_as), G)) check_result(result_code) return cost diff --git a/ot/optim.py b/ot/optim.py index f25e2c9..5a1d605 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -9,12 +9,19 @@ Generic solvers for regularized OT # License: MIT License import numpy as np -from scipy.optimize.linesearch import scalar_search_armijo +import warnings from .lp import emd from .bregman import sinkhorn -from ot.utils import list_to_array +from .utils import list_to_array from .backend import get_backend +with warnings.catch_warnings(): + warnings.simplefilter("ignore") + try: + from scipy.optimize import scalar_search_armijo + except ImportError: + from scipy.optimize.linesearch import scalar_search_armijo + # The corresponding scipy function does not work for matrices diff --git a/ot/unbalanced.py b/ot/unbalanced.py index 15e180b..503cc1e 100644 --- a/ot/unbalanced.py +++ b/ot/unbalanced.py @@ -8,9 +8,9 @@ Regularized Unbalanced OT solvers from __future__ import division import warnings -import numpy as np -from scipy.special import logsumexp +from .backend import get_backend +from .utils import list_to_array # from .utils import unif, dist @@ -43,12 +43,12 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000, Parameters ---------- - a : np.ndarray (dim_a,) + a : array-like (dim_a,) Unnormalized histogram of dimension `dim_a` - b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists) + b : array-like (dim_b,) or array-like (dim_b, n_hists) One or multiple unnormalized histograms of dimension `dim_b`. 
If many, compute all the OT distances :math:`(\mathbf{a}, \mathbf{b}_i)_i` - M : np.ndarray (dim_a, dim_b) + M : array-like (dim_a, dim_b) loss matrix reg : float Entropy regularization term > 0 @@ -70,12 +70,12 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000, Returns ------- if n_hists == 1: - - gamma : (dim_a, dim_b) ndarray + - gamma : (dim_a, dim_b) array-like Optimal transportation matrix for the given parameters - log : dict log dictionary returned only if `log` is `True` else: - - ot_distance : (n_hists,) ndarray + - ot_distance : (n_hists,) array-like the OT distance between :math:`\mathbf{a}` and each of the histograms :math:`\mathbf{b}_i` - log : dict log dictionary returned only if `log` is `True` @@ -172,12 +172,12 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', Parameters ---------- - a : np.ndarray (dim_a,) + a : array-like (dim_a,) Unnormalized histogram of dimension `dim_a` - b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists) + b : array-like (dim_b,) or array-like (dim_b, n_hists) One or multiple unnormalized histograms of dimension `dim_b`. If many, compute all the OT distances :math:`(\mathbf{a}, \mathbf{b}_i)_i` - M : np.ndarray (dim_a, dim_b) + M : array-like (dim_a, dim_b) loss matrix reg : float Entropy regularization term > 0 @@ -198,7 +198,7 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', Returns ------- - ot_distance : (n_hists,) ndarray + ot_distance : (n_hists,) array-like the OT distance between :math:`\mathbf{a}` and each of the histograms :math:`\mathbf{b}_i` log : dict log dictionary returned only if `log` is `True` @@ -239,9 +239,10 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', ot.unbalanced.sinkhorn_reg_scaling: Unbalanced Sinkhorn with epslilon scaling :ref:`[9, 10] ` """ - b = np.asarray(b, dtype=np.float64) + b = list_to_array(b) if len(b.shape) < 2: b = b[:, None] + if method.lower() == 'sinkhorn': return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=numItermax, @@ -291,12 +292,12 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, Parameters ---------- - a : np.ndarray (dim_a,) + a : array-like (dim_a,) Unnormalized histogram of dimension `dim_a` - b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists) + b : array-like (dim_b,) or array-like (dim_b, n_hists) One or multiple unnormalized histograms of dimension `dim_b` If many, compute all the OT distances (a, b_i) - M : np.ndarray (dim_a, dim_b) + M : array-like (dim_a, dim_b) loss matrix reg : float Entropy regularization term > 0 @@ -315,12 +316,12 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, Returns ------- if n_hists == 1: - - gamma : (dim_a, dim_b) ndarray + - gamma : (dim_a, dim_b) array-like Optimal transportation matrix for the given parameters - log : dict log dictionary returned only if `log` is `True` else: - - ot_distance : (n_hists,) ndarray + - ot_distance : (n_hists,) array-like the OT distance between :math:`\mathbf{a}` and each of the histograms :math:`\mathbf{b}_i` - log : dict log dictionary returned only if `log` is `True` @@ -354,17 +355,15 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, ot.optim.cg : General regularized OT """ - - a = np.asarray(a, dtype=np.float64) - b = np.asarray(b, dtype=np.float64) - M = np.asarray(M, dtype=np.float64) + M, a, b = list_to_array(M, a, b) + nx = get_backend(M, a, b) dim_a, dim_b = M.shape if len(a) == 0: - a = np.ones(dim_a, dtype=np.float64) / dim_a + a = 
nx.ones(dim_a, type_as=M) / dim_a if len(b) == 0: - b = np.ones(dim_b, dtype=np.float64) / dim_b + b = nx.ones(dim_b, type_as=M) / dim_b if len(b.shape) > 1: n_hists = b.shape[1] @@ -377,17 +376,14 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, # we assume that no distances are null except those of the diagonal of # distances if n_hists: - u = np.ones((dim_a, 1)) / dim_a - v = np.ones((dim_b, n_hists)) / dim_b + u = nx.ones((dim_a, 1), type_as=M) / dim_a + v = nx.ones((dim_b, n_hists), type_as=M) / dim_b a = a.reshape(dim_a, 1) else: - u = np.ones(dim_a) / dim_a - v = np.ones(dim_b) / dim_b + u = nx.ones(dim_a, type_as=M) / dim_a + v = nx.ones(dim_b, type_as=M) / dim_b - # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute - K = np.empty(M.shape, dtype=M.dtype) - np.divide(M, -reg, out=K) - np.exp(K, out=K) + K = nx.exp(M / (-reg)) fi = reg_m / (reg_m + reg) @@ -397,14 +393,14 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, uprev = u vprev = v - Kv = K.dot(v) + Kv = nx.dot(K, v) u = (a / Kv) ** fi - Ktu = K.T.dot(u) + Ktu = nx.dot(K.T, u) v = (b / Ktu) ** fi - if (np.any(Ktu == 0.) - or np.any(np.isnan(u)) or np.any(np.isnan(v)) - or np.any(np.isinf(u)) or np.any(np.isinf(v))): + if (nx.any(Ktu == 0.) + or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v)) + or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))): # we have reached the machine precision # come back to previous solution and quit loop warnings.warn('Numerical errors at iteration %s' % i) @@ -412,8 +408,12 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, v = vprev break - err_u = abs(u - uprev).max() / max(abs(u).max(), abs(uprev).max(), 1.) - err_v = abs(v - vprev).max() / max(abs(v).max(), abs(vprev).max(), 1.) + err_u = nx.max(nx.abs(u - uprev)) / max( + nx.max(nx.abs(u)), nx.max(nx.abs(uprev)), 1. + ) + err_v = nx.max(nx.abs(v - vprev)) / max( + nx.max(nx.abs(v)), nx.max(nx.abs(vprev)), 1. + ) err = 0.5 * (err_u + err_v) if log: log['err'].append(err) @@ -426,11 +426,11 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, break if log: - log['logu'] = np.log(u + 1e-300) - log['logv'] = np.log(v + 1e-300) + log['logu'] = nx.log(u + 1e-300) + log['logv'] = nx.log(v + 1e-300) if n_hists: # return only loss - res = np.einsum('ik,ij,jk,ij->k', u, K, v, M) + res = nx.einsum('ik,ij,jk,ij->k', u, K, v, M) if log: return res, log else: @@ -475,12 +475,12 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000 Parameters ---------- - a : np.ndarray (dim_a,) + a : array-like (dim_a,) Unnormalized histogram of dimension `dim_a` - b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists) + b : array-like (dim_b,) or array-like (dim_b, n_hists) One or multiple unnormalized histograms of dimension `dim_b`. 
If many, compute all the OT distances :math:`(\mathbf{a}, \mathbf{b}_i)_i` - M : np.ndarray (dim_a, dim_b) + M : array-like (dim_a, dim_b) loss matrix reg : float Entropy regularization term > 0 @@ -501,12 +501,12 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000 Returns ------- if n_hists == 1: - - gamma : (dim_a, dim_b) ndarray + - gamma : (dim_a, dim_b) array-like Optimal transportation matrix for the given parameters - log : dict log dictionary returned only if `log` is `True` else: - - ot_distance : (n_hists,) ndarray + - ot_distance : (n_hists,) array-like the OT distance between :math:`\mathbf{a}` and each of the histograms :math:`\mathbf{b}_i` - log : dict log dictionary returned only if `log` is `True` @@ -538,17 +538,15 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000 ot.optim.cg : General regularized OT """ - - a = np.asarray(a, dtype=np.float64) - b = np.asarray(b, dtype=np.float64) - M = np.asarray(M, dtype=np.float64) + a, b, M = list_to_array(a, b, M) + nx = get_backend(M, a, b) dim_a, dim_b = M.shape if len(a) == 0: - a = np.ones(dim_a, dtype=np.float64) / dim_a + a = nx.ones(dim_a, type_as=M) / dim_a if len(b) == 0: - b = np.ones(dim_b, dtype=np.float64) / dim_b + b = nx.ones(dim_b, type_as=M) / dim_b if len(b.shape) > 1: n_hists = b.shape[1] @@ -561,56 +559,52 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000 # we assume that no distances are null except those of the diagonal of # distances if n_hists: - u = np.ones((dim_a, n_hists)) / dim_a - v = np.ones((dim_b, n_hists)) / dim_b + u = nx.ones((dim_a, n_hists), type_as=M) / dim_a + v = nx.ones((dim_b, n_hists), type_as=M) / dim_b a = a.reshape(dim_a, 1) else: - u = np.ones(dim_a) / dim_a - v = np.ones(dim_b) / dim_b + u = nx.ones(dim_a, type_as=M) / dim_a + v = nx.ones(dim_b, type_as=M) / dim_b # print(reg) - # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute - K = np.empty(M.shape, dtype=M.dtype) - np.divide(M, -reg, out=K) - np.exp(K, out=K) + K = nx.exp(-M / reg) fi = reg_m / (reg_m + reg) cpt = 0 err = 1. - alpha = np.zeros(dim_a) - beta = np.zeros(dim_b) + alpha = nx.zeros(dim_a, type_as=M) + beta = nx.zeros(dim_b, type_as=M) while (err > stopThr and cpt < numItermax): uprev = u vprev = v - Kv = K.dot(v) - f_alpha = np.exp(- alpha / (reg + reg_m)) - f_beta = np.exp(- beta / (reg + reg_m)) + Kv = nx.dot(K, v) + f_alpha = nx.exp(- alpha / (reg + reg_m)) + f_beta = nx.exp(- beta / (reg + reg_m)) if n_hists: f_alpha = f_alpha[:, None] f_beta = f_beta[:, None] u = ((a / (Kv + 1e-16)) ** fi) * f_alpha - Ktu = K.T.dot(u) + Ktu = nx.dot(K.T, u) v = ((b / (Ktu + 1e-16)) ** fi) * f_beta absorbing = False - if (u > tau).any() or (v > tau).any(): + if nx.any(u > tau) or nx.any(v > tau): absorbing = True if n_hists: - alpha = alpha + reg * np.log(np.max(u, 1)) - beta = beta + reg * np.log(np.max(v, 1)) + alpha = alpha + reg * nx.log(nx.max(u, 1)) + beta = beta + reg * nx.log(nx.max(v, 1)) else: - alpha = alpha + reg * np.log(np.max(u)) - beta = beta + reg * np.log(np.max(v)) - K = np.exp((alpha[:, None] + beta[None, :] - - M) / reg) - v = np.ones_like(v) - Kv = K.dot(v) - - if (np.any(Ktu == 0.) - or np.any(np.isnan(u)) or np.any(np.isnan(v)) - or np.any(np.isinf(u)) or np.any(np.isinf(v))): + alpha = alpha + reg * nx.log(nx.max(u)) + beta = beta + reg * nx.log(nx.max(v)) + K = nx.exp((alpha[:, None] + beta[None, :] - M) / reg) + v = nx.ones(v.shape, type_as=v) + Kv = nx.dot(K, v) + + if (nx.any(Ktu == 0.) 
+ or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v)) + or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))): # we have reached the machine precision # come back to previous solution and quit loop warnings.warn('Numerical errors at iteration %s' % cpt) @@ -620,8 +614,9 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000 if (cpt % 10 == 0 and not absorbing) or cpt == 0: # we can speed up the process by checking for the error only all # the 10th iterations - err = abs(u - uprev).max() / max(abs(u).max(), abs(uprev).max(), - 1.) + err = nx.max(nx.abs(u - uprev)) / max( + nx.max(nx.abs(u)), nx.max(nx.abs(uprev)), 1. + ) if log: log['err'].append(err) if verbose: @@ -636,25 +631,30 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000 "Try a larger entropy `reg` or a lower mass `reg_m`." + "Or a larger absorption threshold `tau`.") if n_hists: - logu = alpha[:, None] / reg + np.log(u) - logv = beta[:, None] / reg + np.log(v) + logu = alpha[:, None] / reg + nx.log(u) + logv = beta[:, None] / reg + nx.log(v) else: - logu = alpha / reg + np.log(u) - logv = beta / reg + np.log(v) + logu = alpha / reg + nx.log(u) + logv = beta / reg + nx.log(v) if log: log['logu'] = logu log['logv'] = logv if n_hists: # return only loss - res = logsumexp(np.log(M + 1e-100)[:, :, None] + logu[:, None, :] + - logv[None, :, :] - M[:, :, None] / reg, axis=(0, 1)) - res = np.exp(res) + res = nx.logsumexp( + nx.log(M + 1e-100)[:, :, None] + + logu[:, None, :] + + logv[None, :, :] + - M[:, :, None] / reg, + axis=(0, 1) + ) + res = nx.exp(res) if log: return res, log else: return res else: # return OT matrix - ot_matrix = np.exp(logu[:, None] + logv[None, :] - M / reg) + ot_matrix = nx.exp(logu[:, None] + logv[None, :] - M / reg) if log: return ot_matrix, log else: @@ -683,9 +683,9 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3, Parameters ---------- - A : np.ndarray (dim, n_hists) + A : array-like (dim, n_hists) `n_hists` training distributions :math:`\mathbf{a}_i` of dimension `dim` - M : np.ndarray (dim, dim) + M : array-like (dim, dim) ground metric matrix for OT. reg : float Entropy regularization term > 0 @@ -693,7 +693,7 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3, Marginal relaxation term > 0 tau : float Stabilization threshold for log domain absorption. - weights : np.ndarray (n_hists,) optional + weights : array-like (n_hists,) optional Weight of each distribution (barycentric coodinates) If None, uniform weights are used. 
numItermax : int, optional @@ -708,7 +708,7 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3, Returns ------- - a : (dim,) ndarray + a : (dim,) array-like Unbalanced Wasserstein barycenter log : dict log dictionary return only if log==True in parameters @@ -726,9 +726,12 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3, """ + A, M = list_to_array(A, M) + nx = get_backend(A, M) + dim, n_hists = A.shape if weights is None: - weights = np.ones(n_hists) / n_hists + weights = nx.ones(n_hists, type_as=A) / n_hists else: assert(len(weights) == A.shape[1]) @@ -737,47 +740,43 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3, fi = reg_m / (reg_m + reg) - u = np.ones((dim, n_hists)) / dim - v = np.ones((dim, n_hists)) / dim + u = nx.ones((dim, n_hists), type_as=A) / dim + v = nx.ones((dim, n_hists), type_as=A) / dim # print(reg) - # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute - K = np.empty(M.shape, dtype=M.dtype) - np.divide(M, -reg, out=K) - np.exp(K, out=K) + K = nx.exp(-M / reg) fi = reg_m / (reg_m + reg) cpt = 0 err = 1. - alpha = np.zeros(dim) - beta = np.zeros(dim) - q = np.ones(dim) / dim + alpha = nx.zeros(dim, type_as=A) + beta = nx.zeros(dim, type_as=A) + q = nx.ones(dim, type_as=A) / dim for i in range(numItermax): - qprev = q.copy() - Kv = K.dot(v) - f_alpha = np.exp(- alpha / (reg + reg_m)) - f_beta = np.exp(- beta / (reg + reg_m)) + qprev = nx.copy(q) + Kv = nx.dot(K, v) + f_alpha = nx.exp(- alpha / (reg + reg_m)) + f_beta = nx.exp(- beta / (reg + reg_m)) f_alpha = f_alpha[:, None] f_beta = f_beta[:, None] u = ((A / (Kv + 1e-16)) ** fi) * f_alpha - Ktu = K.T.dot(u) + Ktu = nx.dot(K.T, u) q = (Ktu ** (1 - fi)) * f_beta - q = q.dot(weights) ** (1 / (1 - fi)) + q = nx.dot(q, weights) ** (1 / (1 - fi)) Q = q[:, None] v = ((Q / (Ktu + 1e-16)) ** fi) * f_beta absorbing = False - if (u > tau).any() or (v > tau).any(): + if nx.any(u > tau) or nx.any(v > tau): absorbing = True - alpha = alpha + reg * np.log(np.max(u, 1)) - beta = beta + reg * np.log(np.max(v, 1)) - K = np.exp((alpha[:, None] + beta[None, :] - - M) / reg) - v = np.ones_like(v) - Kv = K.dot(v) - if (np.any(Ktu == 0.) - or np.any(np.isnan(u)) or np.any(np.isnan(v)) - or np.any(np.isinf(u)) or np.any(np.isinf(v))): + alpha = alpha + reg * nx.log(nx.max(u, 1)) + beta = beta + reg * nx.log(nx.max(v, 1)) + K = nx.exp((alpha[:, None] + beta[None, :] - M) / reg) + v = nx.ones(v.shape, type_as=v) + Kv = nx.dot(K, v) + if (nx.any(Ktu == 0.) + or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v)) + or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))): # we have reached the machine precision # come back to previous solution and quit loop warnings.warn('Numerical errors at iteration %s' % cpt) @@ -786,8 +785,9 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3, if (i % 10 == 0 and not absorbing) or i == 0: # we can speed up the process by checking for the error only all # the 10th iterations - err = abs(q - qprev).max() / max(abs(q).max(), - abs(qprev).max(), 1.) + err = nx.max(nx.abs(q - qprev)) / max( + nx.max(nx.abs(q)), nx.max(nx.abs(qprev)), 1. 
+ ) if log: log['err'].append(err) if verbose: @@ -804,8 +804,8 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3, "Or a larger absorption threshold `tau`.") if log: log['niter'] = i - log['logu'] = np.log(u + 1e-300) - log['logv'] = np.log(v + 1e-300) + log['logu'] = nx.log(u + 1e-300) + log['logv'] = nx.log(v + 1e-300) return q, log else: return q @@ -833,15 +833,15 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None, Parameters ---------- - A : np.ndarray (dim, n_hists) + A : array-like (dim, n_hists) `n_hists` training distributions :math:`\mathbf{a}_i` of dimension `dim` - M : np.ndarray (dim, dim) + M : array-like (dim, dim) ground metric matrix for OT. reg : float Entropy regularization term > 0 reg_m: float Marginal relaxation term > 0 - weights : np.ndarray (n_hists,) optional + weights : array-like (n_hists,) optional Weight of each distribution (barycentric coodinates) If None, uniform weights are used. numItermax : int, optional @@ -856,7 +856,7 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None, Returns ------- - a : (dim,) ndarray + a : (dim,) array-like Unbalanced Wasserstein barycenter log : dict log dictionary return only if log==True in parameters @@ -874,40 +874,43 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None, """ + A, M = list_to_array(A, M) + nx = get_backend(A, M) + dim, n_hists = A.shape if weights is None: - weights = np.ones(n_hists) / n_hists + weights = nx.ones(n_hists, type_as=A) / n_hists else: assert(len(weights) == A.shape[1]) if log: log = {'err': []} - K = np.exp(- M / reg) + K = nx.exp(-M / reg) fi = reg_m / (reg_m + reg) - v = np.ones((dim, n_hists)) - u = np.ones((dim, 1)) - q = np.ones(dim) + v = nx.ones((dim, n_hists), type_as=A) + u = nx.ones((dim, 1), type_as=A) + q = nx.ones(dim, type_as=A) err = 1. for i in range(numItermax): - uprev = u.copy() - vprev = v.copy() - qprev = q.copy() + uprev = nx.copy(u) + vprev = nx.copy(v) + qprev = nx.copy(q) - Kv = K.dot(v) + Kv = nx.dot(K, v) u = (A / Kv) ** fi - Ktu = K.T.dot(u) - q = ((Ktu ** (1 - fi)).dot(weights)) + Ktu = nx.dot(K.T, u) + q = nx.dot(Ktu ** (1 - fi), weights) q = q ** (1 / (1 - fi)) Q = q[:, None] v = (Q / Ktu) ** fi - if (np.any(Ktu == 0.) - or np.any(np.isnan(u)) or np.any(np.isnan(v)) - or np.any(np.isinf(u)) or np.any(np.isinf(v))): + if (nx.any(Ktu == 0.) + or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v)) + or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))): # we have reached the machine precision # come back to previous solution and quit loop warnings.warn('Numerical errors at iteration %s' % i) @@ -916,8 +919,9 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None, q = qprev break # compute change in barycenter - err = abs(q - qprev).max() - err /= max(abs(q).max(), abs(qprev).max(), 1.) 
+ err = nx.max(nx.abs(q - qprev)) / max( + nx.max(nx.abs(q)), nx.max(nx.abs(qprev)), 1.0 + ) if log: log['err'].append(err) # if barycenter did not change + at least 10 iterations - stop @@ -932,8 +936,8 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None, if log: log['niter'] = i - log['logu'] = np.log(u + 1e-300) - log['logv'] = np.log(v + 1e-300) + log['logu'] = nx.log(u + 1e-300) + log['logv'] = nx.log(v + 1e-300) return q, log else: return q @@ -961,15 +965,15 @@ def barycenter_unbalanced(A, M, reg, reg_m, method="sinkhorn", weights=None, Parameters ---------- - A : np.ndarray (dim, n_hists) + A : array-like (dim, n_hists) `n_hists` training distributions :math:`\mathbf{a}_i` of dimension `dim` - M : np.ndarray (dim, dim) + M : array-like (dim, dim) ground metric matrix for OT. reg : float Entropy regularization term > 0 reg_m: float Marginal relaxation term > 0 - weights : np.ndarray (n_hists,) optional + weights : array-like (n_hists,) optional Weight of each distribution (barycentric coodinates) If None, uniform weights are used. numItermax : int, optional @@ -984,7 +988,7 @@ def barycenter_unbalanced(A, M, reg, reg_m, method="sinkhorn", weights=None, Returns ------- - a : (dim,) ndarray + a : (dim,) array-like Unbalanced Wasserstein barycenter log : dict log dictionary return only if log==True in parameters diff --git a/ot/utils.py b/ot/utils.py index 725ca00..a23ce7e 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -15,7 +15,7 @@ from scipy.spatial.distance import cdist import sys import warnings from inspect import signature -from .backend import get_backend +from .backend import get_backend, Backend __time_tic_toc = time.time() @@ -51,7 +51,8 @@ def kernel(x1, x2, method='gaussian', sigma=1, **kwargs): def laplacian(x): r"""Compute Laplacian matrix""" - L = np.diag(np.sum(x, axis=0)) - x + nx = get_backend(x) + L = nx.diag(nx.sum(x, axis=0)) - x return L @@ -136,7 +137,7 @@ def unif(n, type_as=None): return np.ones((n,)) / n else: nx = get_backend(type_as) - return nx.ones((n,)) / n + return nx.ones((n,), type_as=type_as) / n def clean_zeros(a, b, M): @@ -296,7 +297,8 @@ def cost_normalization(C, norm=None): def dots(*args): r""" dots function for multiple matrix multiply """ - return reduce(np.dot, args) + nx = get_backend(*args) + return reduce(nx.dot, args) def label_normalization(y, start=0): @@ -314,8 +316,9 @@ def label_normalization(y, start=0): y : array-like, shape (`n1`, ) The input vector of labels normalized according to given start value. """ + nx = get_backend(y) - diff = np.min(np.unique(y)) - start + diff = nx.min(nx.unique(y)) - start if diff != 0: y -= diff return y @@ -482,6 +485,19 @@ class BaseEstimator(object): arguments (no ``*args`` or ``**kwargs``). 
""" + nx: Backend = None + + def _get_backend(self, *arrays): + nx = get_backend( + *[input_ for input_ in arrays if input_ is not None] + ) + if nx.__name__ in ("jax", "tf"): + raise TypeError( + """JAX or TF arrays have been received but domain + adaptation does not support those backend.""") + self.nx = nx + return nx + @classmethod def _get_param_names(cls): r"""Get parameter names for the estimator""" diff --git a/test/test_1d_solver.py b/test/test_1d_solver.py index 6a42cfe..20f307a 100644 --- a/test/test_1d_solver.py +++ b/test/test_1d_solver.py @@ -66,9 +66,7 @@ def test_wasserstein_1d(nx): rho_v = np.abs(rng.randn(n)) rho_v /= rho_v.sum() - xb = nx.from_numpy(x) - rho_ub = nx.from_numpy(rho_u) - rho_vb = nx.from_numpy(rho_v) + xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v) # test 1 : wasserstein_1d should be close to scipy W_1 implementation np.testing.assert_almost_equal(wasserstein_1d(xb, xb, rho_ub, rho_vb, p=1), @@ -98,9 +96,7 @@ def test_wasserstein_1d_type_devices(nx): for tp in nx.__type_list__: print(nx.dtype_device(tp)) - xb = nx.from_numpy(x, type_as=tp) - rho_ub = nx.from_numpy(rho_u, type_as=tp) - rho_vb = nx.from_numpy(rho_v, type_as=tp) + xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v, type_as=tp) res = wasserstein_1d(xb, xb, rho_ub, rho_vb, p=1) @@ -122,17 +118,13 @@ def test_wasserstein_1d_device_tf(): # Check that everything stays on the CPU with tf.device("/CPU:0"): - xb = nx.from_numpy(x) - rho_ub = nx.from_numpy(rho_u) - rho_vb = nx.from_numpy(rho_v) + xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v) res = wasserstein_1d(xb, xb, rho_ub, rho_vb, p=1) nx.assert_same_dtype_device(xb, res) if len(tf.config.list_physical_devices('GPU')) > 0: # Check that everything happens on the GPU - xb = nx.from_numpy(x) - rho_ub = nx.from_numpy(rho_u) - rho_vb = nx.from_numpy(rho_v) + xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v) res = wasserstein_1d(xb, xb, rho_ub, rho_vb, p=1) nx.assert_same_dtype_device(xb, res) assert nx.dtype_device(res)[1].startswith("GPU") @@ -190,9 +182,7 @@ def test_emd1d_type_devices(nx): for tp in nx.__type_list__: print(nx.dtype_device(tp)) - xb = nx.from_numpy(x, type_as=tp) - rho_ub = nx.from_numpy(rho_u, type_as=tp) - rho_vb = nx.from_numpy(rho_v, type_as=tp) + xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v, type_as=tp) emd = ot.emd_1d(xb, xb, rho_ub, rho_vb) emd2 = ot.emd2_1d(xb, xb, rho_ub, rho_vb) @@ -214,9 +204,7 @@ def test_emd1d_device_tf(): # Check that everything stays on the CPU with tf.device("/CPU:0"): - xb = nx.from_numpy(x) - rho_ub = nx.from_numpy(rho_u) - rho_vb = nx.from_numpy(rho_v) + xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v) emd = ot.emd_1d(xb, xb, rho_ub, rho_vb) emd2 = ot.emd2_1d(xb, xb, rho_ub, rho_vb) nx.assert_same_dtype_device(xb, emd) @@ -224,9 +212,7 @@ def test_emd1d_device_tf(): if len(tf.config.list_physical_devices('GPU')) > 0: # Check that everything happens on the GPU - xb = nx.from_numpy(x) - rho_ub = nx.from_numpy(rho_u) - rho_vb = nx.from_numpy(rho_v) + xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v) emd = ot.emd_1d(xb, xb, rho_ub, rho_vb) emd2 = ot.emd2_1d(xb, xb, rho_ub, rho_vb) nx.assert_same_dtype_device(xb, emd) diff --git a/test/test_backend.py b/test/test_backend.py index 027c4cd..311c075 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -217,6 +217,8 @@ def test_empty_backend(): nx.zero_pad(M, v) with pytest.raises(NotImplementedError): nx.argmax(M) + with pytest.raises(NotImplementedError): + nx.argmin(M) with pytest.raises(NotImplementedError): 
nx.mean(M) with pytest.raises(NotImplementedError): @@ -264,12 +266,27 @@ def test_empty_backend(): nx.device_type(M) with pytest.raises(NotImplementedError): nx._bench(lambda x: x, M, n_runs=1) + with pytest.raises(NotImplementedError): + nx.solve(M, v) + with pytest.raises(NotImplementedError): + nx.trace(M) + with pytest.raises(NotImplementedError): + nx.inv(M) + with pytest.raises(NotImplementedError): + nx.sqrtm(M) + with pytest.raises(NotImplementedError): + nx.isfinite(M) + with pytest.raises(NotImplementedError): + nx.array_equal(M, M) + with pytest.raises(NotImplementedError): + nx.is_floating_point(M) def test_func_backends(nx): rnd = np.random.RandomState(0) M = rnd.randn(10, 3) + SquareM = rnd.randn(10, 10) v = rnd.randn(3) val = np.array([1.0]) @@ -288,6 +305,7 @@ def test_func_backends(nx): lst_name = [] Mb = nx.from_numpy(M) + SquareMb = nx.from_numpy(SquareM) vb = nx.from_numpy(v) val = nx.from_numpy(val) @@ -467,6 +485,10 @@ def test_func_backends(nx): lst_b.append(nx.to_numpy(A)) lst_name.append('argmax') + A = nx.argmin(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('argmin') + A = nx.mean(Mb) lst_b.append(nx.to_numpy(A)) lst_name.append('mean') @@ -529,7 +551,11 @@ def test_func_backends(nx): A = nx.where(Mb >= nx.stack([nx.linspace(0, 1, 10)] * 3, axis=1), Mb, 0.0) lst_b.append(nx.to_numpy(A)) - lst_name.append('where') + lst_name.append('where (cond, x, y)') + + A = nx.where(nx.from_numpy(np.array([True, False]))) + lst_b.append(nx.to_numpy(nx.stack(A))) + lst_name.append('where (cond)') A = nx.copy(Mb) lst_b.append(nx.to_numpy(A)) @@ -550,15 +576,47 @@ def test_func_backends(nx): nx._bench(lambda x: x, M, n_runs=1) + A = nx.solve(SquareMb, Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('solve') + + A = nx.trace(SquareMb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('trace') + + A = nx.inv(SquareMb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('matrix inverse') + + A = nx.sqrtm(SquareMb.T @ SquareMb) + lst_b.append(nx.to_numpy(A)) + lst_name.append("matrix square root") + + A = nx.concatenate([vb, nx.from_numpy(np.array([np.inf, np.nan]))], axis=0) + A = nx.isfinite(A) + lst_b.append(nx.to_numpy(A)) + lst_name.append("isfinite") + + assert not nx.array_equal(Mb, vb), "array_equal (shape)" + assert nx.array_equal(Mb, Mb), "array_equal (elements) - expected true" + assert not nx.array_equal( + Mb, Mb + nx.eye(*list(Mb.shape)) + ), "array_equal (elements) - expected false" + + assert nx.is_floating_point(Mb), "is_floating_point - expected true" + assert not nx.is_floating_point( + nx.from_numpy(np.array([0, 1, 2], dtype=int)) + ), "is_floating_point - expected false" + lst_tot.append(lst_b) lst_np = lst_tot[0] lst_b = lst_tot[1] for a1, a2, name in zip(lst_np, lst_b, lst_name): - if not np.allclose(a1, a2): - print('Assert fail on: ', name) - assert np.allclose(a1, a2, atol=1e-7) + np.testing.assert_allclose( + a2, a1, atol=1e-7, err_msg=f'ASSERT FAILED ON: {name}' + ) def test_random_backends(nx): diff --git a/test/test_bregman.py b/test/test_bregman.py index 1419f9b..6c37984 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -155,8 +155,7 @@ def test_sinkhorn_backends(nx): G = ot.sinkhorn(a, a, M, 1) - ab = nx.from_numpy(a) - M_nx = nx.from_numpy(M) + ab, M_nx = nx.from_numpy(a, M) Gb = ot.sinkhorn(ab, ab, M_nx, 1) @@ -176,8 +175,7 @@ def test_sinkhorn2_backends(nx): G = ot.sinkhorn(a, a, M, 1) - ab = nx.from_numpy(a) - M_nx = nx.from_numpy(M) + ab, M_nx = nx.from_numpy(a, M) Gb = ot.sinkhorn2(ab, ab, M_nx, 1) @@ -260,8 +258,7 @@ def 
test_sinkhorn_variants(nx): M = ot.dist(x, x) - ub = nx.from_numpy(u) - M_nx = nx.from_numpy(M) + ub, M_nx = nx.from_numpy(u, M) G = ot.sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10) Gl = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method='sinkhorn_log', stopThr=1e-10)) @@ -298,8 +295,7 @@ def test_sinkhorn_variants_dtype_device(nx, method): for tp in nx.__type_list__: print(nx.dtype_device(tp)) - ub = nx.from_numpy(u, type_as=tp) - Mb = nx.from_numpy(M, type_as=tp) + ub, Mb = nx.from_numpy(u, M, type_as=tp) Gb = ot.sinkhorn(ub, ub, Mb, 1, method=method, stopThr=1e-10) @@ -318,8 +314,7 @@ def test_sinkhorn2_variants_dtype_device(nx, method): for tp in nx.__type_list__: print(nx.dtype_device(tp)) - ub = nx.from_numpy(u, type_as=tp) - Mb = nx.from_numpy(M, type_as=tp) + ub, Mb = nx.from_numpy(u, M, type_as=tp) lossb = ot.sinkhorn2(ub, ub, Mb, 1, method=method, stopThr=1e-10) @@ -337,8 +332,7 @@ def test_sinkhorn2_variants_device_tf(method): # Check that everything stays on the CPU with tf.device("/CPU:0"): - ub = nx.from_numpy(u) - Mb = nx.from_numpy(M) + ub, Mb = nx.from_numpy(u, M) Gb = ot.sinkhorn(ub, ub, Mb, 1, method=method, stopThr=1e-10) lossb = ot.sinkhorn2(ub, ub, Mb, 1, method=method, stopThr=1e-10) nx.assert_same_dtype_device(Mb, Gb) @@ -346,8 +340,7 @@ def test_sinkhorn2_variants_device_tf(method): if len(tf.config.list_physical_devices('GPU')) > 0: # Check that everything happens on the GPU - ub = nx.from_numpy(u) - Mb = nx.from_numpy(M) + ub, Mb = nx.from_numpy(u, M) Gb = ot.sinkhorn(ub, ub, Mb, 1, method=method, stopThr=1e-10) lossb = ot.sinkhorn2(ub, ub, Mb, 1, method=method, stopThr=1e-10) nx.assert_same_dtype_device(Mb, Gb) @@ -370,9 +363,7 @@ def test_sinkhorn_variants_multi_b(nx): M = ot.dist(x, x) - ub = nx.from_numpy(u) - bb = nx.from_numpy(b) - M_nx = nx.from_numpy(M) + ub, bb, M_nx = nx.from_numpy(u, b, M) G = ot.sinkhorn(u, b, M, 1, method='sinkhorn', stopThr=1e-10) Gl = nx.to_numpy(ot.sinkhorn(ub, bb, M_nx, 1, method='sinkhorn_log', stopThr=1e-10)) @@ -400,9 +391,7 @@ def test_sinkhorn2_variants_multi_b(nx): M = ot.dist(x, x) - ub = nx.from_numpy(u) - bb = nx.from_numpy(b) - M_nx = nx.from_numpy(M) + ub, bb, M_nx = nx.from_numpy(u, b, M) G = ot.sinkhorn2(u, b, M, 1, method='sinkhorn', stopThr=1e-10) Gl = nx.to_numpy(ot.sinkhorn2(ub, bb, M_nx, 1, method='sinkhorn_log', stopThr=1e-10)) @@ -483,9 +472,7 @@ def test_barycenter(nx, method, verbose, warn): alpha = 0.5 # 0<=alpha<=1 weights = np.array([1 - alpha, alpha]) - A_nx = nx.from_numpy(A) - M_nx = nx.from_numpy(M) - weights_nx = nx.from_numpy(weights) + A_nx, M_nx, weights_nx = nx.from_numpy(A, M, weights) reg = 1e-2 if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log": @@ -523,9 +510,7 @@ def test_barycenter_debiased(nx, method, verbose, warn): alpha = 0.5 # 0<=alpha<=1 weights = np.array([1 - alpha, alpha]) - A_nx = nx.from_numpy(A) - M_nx = nx.from_numpy(M) - weights_nx = nx.from_numpy(weights) + A_nx, M_nx, weights_nx = nx.from_numpy(A, M, weights) # wasserstein reg = 1e-2 @@ -594,9 +579,7 @@ def test_barycenter_stabilization(nx): alpha = 0.5 # 0<=alpha<=1 weights = np.array([1 - alpha, alpha]) - A_nx = nx.from_numpy(A) - M_nx = nx.from_numpy(M) - weights_b = nx.from_numpy(weights) + A_nx, M_nx, weights_b = nx.from_numpy(A, M, weights) # wasserstein reg = 1e-2 @@ -697,11 +680,7 @@ def test_unmix(nx): M0 /= M0.max() h0 = ot.unif(2) - ab = nx.from_numpy(a) - Db = nx.from_numpy(D) - M_nx = nx.from_numpy(M) - M0b = nx.from_numpy(M0) - h0b = nx.from_numpy(h0) + ab, Db, M_nx, M0b, h0b = nx.from_numpy(a, 
D, M, M0, h0) # wasserstein reg = 1e-3 @@ -727,12 +706,7 @@ def test_empirical_sinkhorn(nx): M = ot.dist(X_s, X_t) M_m = ot.dist(X_s, X_t, metric='euclidean') - ab = nx.from_numpy(a) - bb = nx.from_numpy(b) - X_sb = nx.from_numpy(X_s) - X_tb = nx.from_numpy(X_t) - M_nx = nx.from_numpy(M, type_as=ab) - M_mb = nx.from_numpy(M_m, type_as=ab) + ab, bb, X_sb, X_tb, M_nx, M_mb = nx.from_numpy(a, b, X_s, X_t, M, M_m) G_sqe = nx.to_numpy(ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1)) sinkhorn_sqe = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1)) @@ -776,12 +750,7 @@ def test_lazy_empirical_sinkhorn(nx): M = ot.dist(X_s, X_t) M_m = ot.dist(X_s, X_t, metric='euclidean') - ab = nx.from_numpy(a) - bb = nx.from_numpy(b) - X_sb = nx.from_numpy(X_s) - X_tb = nx.from_numpy(X_t) - M_nx = nx.from_numpy(M, type_as=ab) - M_mb = nx.from_numpy(M_m, type_as=ab) + ab, bb, X_sb, X_tb, M_nx, M_mb = nx.from_numpy(a, b, X_s, X_t, M, M_m) f, g = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=(1, 3), verbose=True) f, g = nx.to_numpy(f), nx.to_numpy(g) @@ -825,19 +794,13 @@ def test_empirical_sinkhorn_divergence(nx): a = np.linspace(1, n, n) a /= a.sum() b = ot.unif(n) - X_s = np.reshape(np.arange(n), (n, 1)) - X_t = np.reshape(np.arange(0, n * 2, 2), (n, 1)) + X_s = np.reshape(np.arange(n, dtype=np.float64), (n, 1)) + X_t = np.reshape(np.arange(0, n * 2, 2, dtype=np.float64), (n, 1)) M = ot.dist(X_s, X_t) M_s = ot.dist(X_s, X_s) M_t = ot.dist(X_t, X_t) - ab = nx.from_numpy(a) - bb = nx.from_numpy(b) - X_sb = nx.from_numpy(X_s) - X_tb = nx.from_numpy(X_t) - M_nx = nx.from_numpy(M, type_as=ab) - M_sb = nx.from_numpy(M_s, type_as=ab) - M_tb = nx.from_numpy(M_t, type_as=ab) + ab, bb, X_sb, X_tb, M_nx, M_sb, M_tb = nx.from_numpy(a, b, X_s, X_t, M, M_s, M_t) emp_sinkhorn_div = nx.to_numpy(ot.bregman.empirical_sinkhorn_divergence(X_sb, X_tb, 1, a=ab, b=bb)) sinkhorn_div = nx.to_numpy( @@ -872,9 +835,7 @@ def test_stabilized_vs_sinkhorn_multidim(nx): M /= np.median(M) epsilon = 0.1 - ab = nx.from_numpy(a) - bb = nx.from_numpy(b) - M_nx = nx.from_numpy(M, type_as=ab) + ab, bb, M_nx = nx.from_numpy(a, b, M) G_np, _ = ot.bregman.sinkhorn(a, b, M, reg=epsilon, method="sinkhorn", log=True) G, log = ot.bregman.sinkhorn(ab, bb, M_nx, reg=epsilon, @@ -936,9 +897,7 @@ def test_screenkhorn(nx): x = rng.randn(n, 2) M = ot.dist(x, x) - ab = nx.from_numpy(a) - bb = nx.from_numpy(b) - M_nx = nx.from_numpy(M, type_as=ab) + ab, bb, M_nx = nx.from_numpy(a, b, M) # sinkhorn G_sink = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1e-1)) diff --git a/test/test_da.py b/test/test_da.py index 9f2bb50..4bf0ab1 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -19,7 +19,32 @@ except ImportError: nosklearn = True -def test_sinkhorn_lpl1_transport_class(): +def test_class_jax_tf(): + backends = [] + from ot.backend import jax, tf + if jax: + backends.append(ot.backend.JaxBackend()) + if tf: + backends.append(ot.backend.TensorflowBackend()) + + for nx in backends: + ns = 150 + nt = 200 + + Xs, ys = make_data_classif('3gauss', ns) + Xt, yt = make_data_classif('3gauss2', nt) + + Xs, ys, Xt, yt = nx.from_numpy(Xs, ys, Xt, yt) + + otda = ot.da.SinkhornLpl1Transport() + + with pytest.raises(TypeError): + otda.fit(Xs=Xs, ys=ys, Xt=Xt) + + +@pytest.skip_backend("jax") +@pytest.skip_backend("tf") +def test_sinkhorn_lpl1_transport_class(nx): """test_sinkhorn_transport """ @@ -29,6 +54,8 @@ def test_sinkhorn_lpl1_transport_class(): Xs, ys = make_data_classif('3gauss', ns) Xt, yt = make_data_classif('3gauss2', nt) + Xs, ys, Xt, yt = 
nx.from_numpy(Xs, ys, Xt, yt) + otda = ot.da.SinkhornLpl1Transport() # test its computed @@ -44,15 +71,15 @@ def test_sinkhorn_lpl1_transport_class(): mu_s = unif(ns) mu_t = unif(nt) assert_allclose( - np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3) + nx.to_numpy(nx.sum(otda.coupling_, axis=0)), mu_t, rtol=1e-3, atol=1e-3) assert_allclose( - np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3) + nx.to_numpy(nx.sum(otda.coupling_, axis=1)), mu_s, rtol=1e-3, atol=1e-3) # test transform transp_Xs = otda.transform(Xs=Xs) assert_equal(transp_Xs.shape, Xs.shape) - Xs_new, _ = make_data_classif('3gauss', ns + 1) + Xs_new = nx.from_numpy(make_data_classif('3gauss', ns + 1)[0]) transp_Xs_new = otda.transform(Xs_new) # check that the oos method is working @@ -62,7 +89,7 @@ def test_sinkhorn_lpl1_transport_class(): transp_Xt = otda.inverse_transform(Xt=Xt) assert_equal(transp_Xt.shape, Xt.shape) - Xt_new, _ = make_data_classif('3gauss2', nt + 1) + Xt_new = nx.from_numpy(make_data_classif('3gauss2', nt + 1)[0]) transp_Xt_new = otda.inverse_transform(Xt=Xt_new) # check that the oos method is working @@ -85,24 +112,26 @@ def test_sinkhorn_lpl1_transport_class(): # test unsupervised vs semi-supervised mode otda_unsup = ot.da.SinkhornLpl1Transport() otda_unsup.fit(Xs=Xs, ys=ys, Xt=Xt) - n_unsup = np.sum(otda_unsup.cost_) + n_unsup = nx.sum(otda_unsup.cost_) otda_semi = ot.da.SinkhornLpl1Transport() otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt) assert_equal(otda_semi.cost_.shape, ((Xs.shape[0], Xt.shape[0]))) - n_semisup = np.sum(otda_semi.cost_) + n_semisup = nx.sum(otda_semi.cost_) # check that the cost matrix norms are indeed different assert n_unsup != n_semisup, "semisupervised mode not working" # check that the coupling forbids mass transport between labeled source # and labeled target samples - mass_semi = np.sum( + mass_semi = nx.sum( otda_semi.coupling_[otda_semi.cost_ == otda_semi.limit_max]) assert mass_semi == 0, "semisupervised mode not working" -def test_sinkhorn_l1l2_transport_class(): +@pytest.skip_backend("jax") +@pytest.skip_backend("tf") +def test_sinkhorn_l1l2_transport_class(nx): """test_sinkhorn_transport """ @@ -112,6 +141,8 @@ def test_sinkhorn_l1l2_transport_class(): Xs, ys = make_data_classif('3gauss', ns) Xt, yt = make_data_classif('3gauss2', nt) + Xs, ys, Xt, yt = nx.from_numpy(Xs, ys, Xt, yt) + otda = ot.da.SinkhornL1l2Transport() # test its computed @@ -128,15 +159,15 @@ def test_sinkhorn_l1l2_transport_class(): mu_s = unif(ns) mu_t = unif(nt) assert_allclose( - np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3) + nx.to_numpy(nx.sum(otda.coupling_, axis=0)), mu_t, rtol=1e-3, atol=1e-3) assert_allclose( - np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3) + nx.to_numpy(nx.sum(otda.coupling_, axis=1)), mu_s, rtol=1e-3, atol=1e-3) # test transform transp_Xs = otda.transform(Xs=Xs) assert_equal(transp_Xs.shape, Xs.shape) - Xs_new, _ = make_data_classif('3gauss', ns + 1) + Xs_new = nx.from_numpy(make_data_classif('3gauss', ns + 1)[0]) transp_Xs_new = otda.transform(Xs_new) # check that the oos method is working @@ -156,7 +187,7 @@ def test_sinkhorn_l1l2_transport_class(): assert_equal(transp_ys.shape[0], ys.shape[0]) assert_equal(transp_ys.shape[1], len(np.unique(yt))) - Xt_new, _ = make_data_classif('3gauss2', nt + 1) + Xt_new = nx.from_numpy(make_data_classif('3gauss2', nt + 1)[0]) transp_Xt_new = otda.inverse_transform(Xt=Xt_new) # check that the oos method is working @@ -169,22 +200,22 @@ def test_sinkhorn_l1l2_transport_class(): # test 
unsupervised vs semi-supervised mode otda_unsup = ot.da.SinkhornL1l2Transport() otda_unsup.fit(Xs=Xs, ys=ys, Xt=Xt) - n_unsup = np.sum(otda_unsup.cost_) + n_unsup = nx.sum(otda_unsup.cost_) otda_semi = ot.da.SinkhornL1l2Transport() otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt) assert_equal(otda_semi.cost_.shape, ((Xs.shape[0], Xt.shape[0]))) - n_semisup = np.sum(otda_semi.cost_) + n_semisup = nx.sum(otda_semi.cost_) # check that the cost matrix norms are indeed different assert n_unsup != n_semisup, "semisupervised mode not working" # check that the coupling forbids mass transport between labeled source # and labeled target samples - mass_semi = np.sum( + mass_semi = nx.sum( otda_semi.coupling_[otda_semi.cost_ == otda_semi.limit_max]) mass_semi = otda_semi.coupling_[otda_semi.cost_ == otda_semi.limit_max] - assert_allclose(mass_semi, np.zeros_like(mass_semi), + assert_allclose(nx.to_numpy(mass_semi), np.zeros(list(mass_semi.shape)), rtol=1e-9, atol=1e-9) # check everything runs well with log=True @@ -193,7 +224,9 @@ def test_sinkhorn_l1l2_transport_class(): assert len(otda.log_.keys()) != 0 -def test_sinkhorn_transport_class(): +@pytest.skip_backend("jax") +@pytest.skip_backend("tf") +def test_sinkhorn_transport_class(nx): """test_sinkhorn_transport """ @@ -203,6 +236,8 @@ def test_sinkhorn_transport_class(): Xs, ys = make_data_classif('3gauss', ns) Xt, yt = make_data_classif('3gauss2', nt) + Xs, ys, Xt, yt = nx.from_numpy(Xs, ys, Xt, yt) + otda = ot.da.SinkhornTransport() # test its computed @@ -219,15 +254,15 @@ def test_sinkhorn_transport_class(): mu_s = unif(ns) mu_t = unif(nt) assert_allclose( - np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3) + nx.to_numpy(nx.sum(otda.coupling_, axis=0)), mu_t, rtol=1e-3, atol=1e-3) assert_allclose( - np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3) + nx.to_numpy(nx.sum(otda.coupling_, axis=1)), mu_s, rtol=1e-3, atol=1e-3) # test transform transp_Xs = otda.transform(Xs=Xs) assert_equal(transp_Xs.shape, Xs.shape) - Xs_new, _ = make_data_classif('3gauss', ns + 1) + Xs_new = nx.from_numpy(make_data_classif('3gauss', ns + 1)[0]) transp_Xs_new = otda.transform(Xs_new) # check that the oos method is working @@ -247,7 +282,7 @@ def test_sinkhorn_transport_class(): assert_equal(transp_ys.shape[0], ys.shape[0]) assert_equal(transp_ys.shape[1], len(np.unique(yt))) - Xt_new, _ = make_data_classif('3gauss2', nt + 1) + Xt_new = nx.from_numpy(make_data_classif('3gauss2', nt + 1)[0]) transp_Xt_new = otda.inverse_transform(Xt=Xt_new) # check that the oos method is working @@ -260,19 +295,19 @@ def test_sinkhorn_transport_class(): # test unsupervised vs semi-supervised mode otda_unsup = ot.da.SinkhornTransport() otda_unsup.fit(Xs=Xs, Xt=Xt) - n_unsup = np.sum(otda_unsup.cost_) + n_unsup = nx.sum(otda_unsup.cost_) otda_semi = ot.da.SinkhornTransport() otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt) assert_equal(otda_semi.cost_.shape, ((Xs.shape[0], Xt.shape[0]))) - n_semisup = np.sum(otda_semi.cost_) + n_semisup = nx.sum(otda_semi.cost_) # check that the cost matrix norms are indeed different assert n_unsup != n_semisup, "semisupervised mode not working" # check that the coupling forbids mass transport between labeled source # and labeled target samples - mass_semi = np.sum( + mass_semi = nx.sum( otda_semi.coupling_[otda_semi.cost_ == otda_semi.limit_max]) assert mass_semi == 0, "semisupervised mode not working" @@ -282,7 +317,9 @@ def test_sinkhorn_transport_class(): assert len(otda.log_.keys()) != 0 -def test_unbalanced_sinkhorn_transport_class(): 
+@pytest.skip_backend("jax") +@pytest.skip_backend("tf") +def test_unbalanced_sinkhorn_transport_class(nx): """test_sinkhorn_transport """ @@ -292,6 +329,8 @@ def test_unbalanced_sinkhorn_transport_class(): Xs, ys = make_data_classif('3gauss', ns) Xt, yt = make_data_classif('3gauss2', nt) + Xs, ys, Xt, yt = nx.from_numpy(Xs, ys, Xt, yt) + otda = ot.da.UnbalancedSinkhornTransport() # test its computed @@ -318,7 +357,7 @@ def test_unbalanced_sinkhorn_transport_class(): assert_equal(transp_ys.shape[0], ys.shape[0]) assert_equal(transp_ys.shape[1], len(np.unique(yt))) - Xs_new, _ = make_data_classif('3gauss', ns + 1) + Xs_new = nx.from_numpy(make_data_classif('3gauss', ns + 1)[0]) transp_Xs_new = otda.transform(Xs_new) # check that the oos method is working @@ -328,7 +367,7 @@ def test_unbalanced_sinkhorn_transport_class(): transp_Xt = otda.inverse_transform(Xt=Xt) assert_equal(transp_Xt.shape, Xt.shape) - Xt_new, _ = make_data_classif('3gauss2', nt + 1) + Xt_new = nx.from_numpy(make_data_classif('3gauss2', nt + 1)[0]) transp_Xt_new = otda.inverse_transform(Xt=Xt_new) # check that the oos method is working @@ -341,12 +380,12 @@ def test_unbalanced_sinkhorn_transport_class(): # test unsupervised vs semi-supervised mode otda_unsup = ot.da.SinkhornTransport() otda_unsup.fit(Xs=Xs, Xt=Xt) - n_unsup = np.sum(otda_unsup.cost_) + n_unsup = nx.sum(otda_unsup.cost_) otda_semi = ot.da.SinkhornTransport() otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt) assert_equal(otda_semi.cost_.shape, ((Xs.shape[0], Xt.shape[0]))) - n_semisup = np.sum(otda_semi.cost_) + n_semisup = nx.sum(otda_semi.cost_) # check that the cost matrix norms are indeed different assert n_unsup != n_semisup, "semisupervised mode not working" @@ -357,7 +396,9 @@ def test_unbalanced_sinkhorn_transport_class(): assert len(otda.log_.keys()) != 0 -def test_emd_transport_class(): +@pytest.skip_backend("jax") +@pytest.skip_backend("tf") +def test_emd_transport_class(nx): """test_sinkhorn_transport """ @@ -367,6 +408,8 @@ def test_emd_transport_class(): Xs, ys = make_data_classif('3gauss', ns) Xt, yt = make_data_classif('3gauss2', nt) + Xs, ys, Xt, yt = nx.from_numpy(Xs, ys, Xt, yt) + otda = ot.da.EMDTransport() # test its computed @@ -382,15 +425,15 @@ def test_emd_transport_class(): mu_s = unif(ns) mu_t = unif(nt) assert_allclose( - np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3) + nx.to_numpy(nx.sum(otda.coupling_, axis=0)), mu_t, rtol=1e-3, atol=1e-3) assert_allclose( - np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3) + nx.to_numpy(nx.sum(otda.coupling_, axis=1)), mu_s, rtol=1e-3, atol=1e-3) # test transform transp_Xs = otda.transform(Xs=Xs) assert_equal(transp_Xs.shape, Xs.shape) - Xs_new, _ = make_data_classif('3gauss', ns + 1) + Xs_new = nx.from_numpy(make_data_classif('3gauss', ns + 1)[0]) transp_Xs_new = otda.transform(Xs_new) # check that the oos method is working @@ -410,7 +453,7 @@ def test_emd_transport_class(): assert_equal(transp_ys.shape[0], ys.shape[0]) assert_equal(transp_ys.shape[1], len(np.unique(yt))) - Xt_new, _ = make_data_classif('3gauss2', nt + 1) + Xt_new = nx.from_numpy(make_data_classif('3gauss2', nt + 1)[0]) transp_Xt_new = otda.inverse_transform(Xt=Xt_new) # check that the oos method is working @@ -423,28 +466,32 @@ def test_emd_transport_class(): # test unsupervised vs semi-supervised mode otda_unsup = ot.da.EMDTransport() otda_unsup.fit(Xs=Xs, ys=ys, Xt=Xt) - n_unsup = np.sum(otda_unsup.cost_) + n_unsup = nx.sum(otda_unsup.cost_) otda_semi = ot.da.EMDTransport() otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, 
yt=yt) assert_equal(otda_semi.cost_.shape, ((Xs.shape[0], Xt.shape[0]))) - n_semisup = np.sum(otda_semi.cost_) + n_semisup = nx.sum(otda_semi.cost_) # check that the cost matrix norms are indeed different assert n_unsup != n_semisup, "semisupervised mode not working" # check that the coupling forbids mass transport between labeled source # and labeled target samples - mass_semi = np.sum( + mass_semi = nx.sum( otda_semi.coupling_[otda_semi.cost_ == otda_semi.limit_max]) mass_semi = otda_semi.coupling_[otda_semi.cost_ == otda_semi.limit_max] # we need to use a small tolerance here, otherwise the test breaks - assert_allclose(mass_semi, np.zeros_like(mass_semi), + assert_allclose(nx.to_numpy(mass_semi), np.zeros(list(mass_semi.shape)), rtol=1e-2, atol=1e-2) -def test_mapping_transport_class(): +@pytest.skip_backend("jax") +@pytest.skip_backend("tf") +@pytest.mark.parametrize("kernel", ["linear", "gaussian"]) +@pytest.mark.parametrize("bias", ["unbiased", "biased"]) +def test_mapping_transport_class(nx, kernel, bias): """test_mapping_transport """ @@ -455,101 +502,29 @@ def test_mapping_transport_class(): Xt, yt = make_data_classif('3gauss2', nt) Xs_new, _ = make_data_classif('3gauss', ns + 1) - ########################################################################## - # kernel == linear mapping tests - ########################################################################## + Xs, Xt, Xs_new = nx.from_numpy(Xs, Xt, Xs_new) - # check computation and dimensions if bias == False - otda = ot.da.MappingTransport(kernel="linear", bias=False) + # Mapping tests + bias = bias == "biased" + otda = ot.da.MappingTransport(kernel=kernel, bias=bias) otda.fit(Xs=Xs, Xt=Xt) assert hasattr(otda, "coupling_") assert hasattr(otda, "mapping_") assert hasattr(otda, "log_") assert_equal(otda.coupling_.shape, ((Xs.shape[0], Xt.shape[0]))) - assert_equal(otda.mapping_.shape, ((Xs.shape[1], Xt.shape[1]))) + S = Xs.shape[0] if kernel == "gaussian" else Xs.shape[1] # if linear + if bias: + S += 1 + assert_equal(otda.mapping_.shape, ((S, Xt.shape[1]))) # test margin constraints mu_s = unif(ns) mu_t = unif(nt) assert_allclose( - np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3) + nx.to_numpy(nx.sum(otda.coupling_, axis=0)), mu_t, rtol=1e-3, atol=1e-3) assert_allclose( - np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3) - - # test transform - transp_Xs = otda.transform(Xs=Xs) - assert_equal(transp_Xs.shape, Xs.shape) - - transp_Xs_new = otda.transform(Xs_new) - - # check that the oos method is working - assert_equal(transp_Xs_new.shape, Xs_new.shape) - - # check computation and dimensions if bias == True - otda = ot.da.MappingTransport(kernel="linear", bias=True) - otda.fit(Xs=Xs, Xt=Xt) - assert_equal(otda.coupling_.shape, ((Xs.shape[0], Xt.shape[0]))) - assert_equal(otda.mapping_.shape, ((Xs.shape[1] + 1, Xt.shape[1]))) - - # test margin constraints - mu_s = unif(ns) - mu_t = unif(nt) - assert_allclose( - np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3) - assert_allclose( - np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3) - - # test transform - transp_Xs = otda.transform(Xs=Xs) - assert_equal(transp_Xs.shape, Xs.shape) - - transp_Xs_new = otda.transform(Xs_new) - - # check that the oos method is working - assert_equal(transp_Xs_new.shape, Xs_new.shape) - - ########################################################################## - # kernel == gaussian mapping tests - ########################################################################## - - # check computation and 
dimensions if bias == False - otda = ot.da.MappingTransport(kernel="gaussian", bias=False) - otda.fit(Xs=Xs, Xt=Xt) - - assert_equal(otda.coupling_.shape, ((Xs.shape[0], Xt.shape[0]))) - assert_equal(otda.mapping_.shape, ((Xs.shape[0], Xt.shape[1]))) - - # test margin constraints - mu_s = unif(ns) - mu_t = unif(nt) - assert_allclose( - np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3) - assert_allclose( - np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3) - - # test transform - transp_Xs = otda.transform(Xs=Xs) - assert_equal(transp_Xs.shape, Xs.shape) - - transp_Xs_new = otda.transform(Xs_new) - - # check that the oos method is working - assert_equal(transp_Xs_new.shape, Xs_new.shape) - - # check computation and dimensions if bias == True - otda = ot.da.MappingTransport(kernel="gaussian", bias=True) - otda.fit(Xs=Xs, Xt=Xt) - assert_equal(otda.coupling_.shape, ((Xs.shape[0], Xt.shape[0]))) - assert_equal(otda.mapping_.shape, ((Xs.shape[0] + 1, Xt.shape[1]))) - - # test margin constraints - mu_s = unif(ns) - mu_t = unif(nt) - assert_allclose( - np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3) - assert_allclose( - np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3) + nx.to_numpy(nx.sum(otda.coupling_, axis=1)), mu_s, rtol=1e-3, atol=1e-3) # test transform transp_Xs = otda.transform(Xs=Xs) @@ -561,29 +536,39 @@ def test_mapping_transport_class(): assert_equal(transp_Xs_new.shape, Xs_new.shape) # check everything runs well with log=True - otda = ot.da.MappingTransport(kernel="gaussian", log=True) + otda = ot.da.MappingTransport(kernel=kernel, bias=bias, log=True) otda.fit(Xs=Xs, Xt=Xt) assert len(otda.log_.keys()) != 0 + +@pytest.skip_backend("jax") +@pytest.skip_backend("tf") +def test_mapping_transport_class_specific_seed(nx): # check that it does not crash when derphi is very close to 0 + ns = 20 + nt = 30 np.random.seed(39) Xs, ys = make_data_classif('3gauss', ns) Xt, yt = make_data_classif('3gauss2', nt) otda = ot.da.MappingTransport(kernel="gaussian", bias=False) - otda.fit(Xs=Xs, Xt=Xt) + otda.fit(Xs=nx.from_numpy(Xs), Xt=nx.from_numpy(Xt)) np.random.seed(None) -def test_linear_mapping(): +@pytest.skip_backend("jax") +@pytest.skip_backend("tf") +def test_linear_mapping(nx): ns = 150 nt = 200 Xs, ys = make_data_classif('3gauss', ns) Xt, yt = make_data_classif('3gauss2', nt) - A, b = ot.da.OT_mapping_linear(Xs, Xt) + Xsb, Xtb = nx.from_numpy(Xs, Xt) - Xst = Xs.dot(A) + b + A, b = ot.da.OT_mapping_linear(Xsb, Xtb) + + Xst = nx.to_numpy(nx.dot(Xsb, A) + b) Ct = np.cov(Xt.T) Cst = np.cov(Xst.T) @@ -591,22 +576,26 @@ def test_linear_mapping(): np.testing.assert_allclose(Ct, Cst, rtol=1e-2, atol=1e-2) -def test_linear_mapping_class(): +@pytest.skip_backend("jax") +@pytest.skip_backend("tf") +def test_linear_mapping_class(nx): ns = 150 nt = 200 Xs, ys = make_data_classif('3gauss', ns) Xt, yt = make_data_classif('3gauss2', nt) + Xsb, Xtb = nx.from_numpy(Xs, Xt) + otmap = ot.da.LinearTransport() - otmap.fit(Xs=Xs, Xt=Xt) + otmap.fit(Xs=Xsb, Xt=Xtb) assert hasattr(otmap, "A_") assert hasattr(otmap, "B_") assert hasattr(otmap, "A1_") assert hasattr(otmap, "B1_") - Xst = otmap.transform(Xs=Xs) + Xst = nx.to_numpy(otmap.transform(Xs=Xsb)) Ct = np.cov(Xt.T) Cst = np.cov(Xst.T) @@ -614,7 +603,9 @@ def test_linear_mapping_class(): np.testing.assert_allclose(Ct, Cst, rtol=1e-2, atol=1e-2) -def test_jcpot_transport_class(): +@pytest.skip_backend("jax") +@pytest.skip_backend("tf") +def test_jcpot_transport_class(nx): """test_jcpot_transport """ @@ -627,6 +618,8 @@ def 
test_jcpot_transport_class(): Xt, yt = make_data_classif('3gauss2', nt) + Xs1, ys1, Xs2, ys2, Xt, yt = nx.from_numpy(Xs1, ys1, Xs2, ys2, Xt, yt) + Xs = [Xs1, Xs2] ys = [ys1, ys2] @@ -649,19 +642,24 @@ def test_jcpot_transport_class(): for i in range(len(Xs)): # test margin constraints w.r.t. uniform target weights for each coupling matrix assert_allclose( - np.sum(otda.coupling_[i], axis=0), mu_t, rtol=1e-3, atol=1e-3) + nx.to_numpy(nx.sum(otda.coupling_[i], axis=0)), mu_t, rtol=1e-3, atol=1e-3) # test margin constraints w.r.t. modified source weights for each source domain assert_allclose( - np.dot(otda.log_['D1'][i], np.sum(otda.coupling_[i], axis=1)), otda.proportions_, rtol=1e-3, - atol=1e-3) + nx.to_numpy( + nx.dot(otda.log_['D1'][i], nx.sum(otda.coupling_[i], axis=1)) + ), + nx.to_numpy(otda.proportions_), + rtol=1e-3, + atol=1e-3 + ) # test transform transp_Xs = otda.transform(Xs=Xs) [assert_equal(x.shape, y.shape) for x, y in zip(transp_Xs, Xs)] - Xs_new, _ = make_data_classif('3gauss', ns1 + 1) + Xs_new = nx.from_numpy(make_data_classif('3gauss', ns1 + 1)[0]) transp_Xs_new = otda.transform(Xs_new) # check that the oos method is working @@ -670,15 +668,16 @@ def test_jcpot_transport_class(): # check label propagation transp_yt = otda.transform_labels(ys) assert_equal(transp_yt.shape[0], yt.shape[0]) - assert_equal(transp_yt.shape[1], len(np.unique(ys))) + assert_equal(transp_yt.shape[1], len(np.unique(nx.to_numpy(*ys)))) # check inverse label propagation transp_ys = otda.inverse_transform_labels(yt) - [assert_equal(x.shape[0], y.shape[0]) for x, y in zip(transp_ys, ys)] - [assert_equal(x.shape[1], len(np.unique(y))) for x, y in zip(transp_ys, ys)] + for x, y in zip(transp_ys, ys): + assert_equal(x.shape[0], y.shape[0]) + assert_equal(x.shape[1], len(np.unique(nx.to_numpy(y)))) -def test_jcpot_barycenter(): +def test_jcpot_barycenter(nx): """test_jcpot_barycenter """ @@ -695,19 +694,23 @@ def test_jcpot_barycenter(): Xs1, ys1 = make_data_classif('2gauss_prop', ns1, nz=sigma, p=ps1) Xs2, ys2 = make_data_classif('2gauss_prop', ns2, nz=sigma, p=ps2) - Xt, yt = make_data_classif('2gauss_prop', nt, nz=sigma, p=pt) + Xt, _ = make_data_classif('2gauss_prop', nt, nz=sigma, p=pt) - Xs = [Xs1, Xs2] - ys = [ys1, ys2] + Xs1b, ys1b, Xs2b, ys2b, Xtb = nx.from_numpy(Xs1, ys1, Xs2, ys2, Xt) - prop = ot.bregman.jcpot_barycenter(Xs, ys, Xt, reg=.5, metric='sqeuclidean', + Xsb = [Xs1b, Xs2b] + ysb = [ys1b, ys2b] + + prop = ot.bregman.jcpot_barycenter(Xsb, ysb, Xtb, reg=.5, metric='sqeuclidean', numItermax=10000, stopThr=1e-9, verbose=False, log=False) - np.testing.assert_allclose(prop, [1 - pt, pt], rtol=1e-3, atol=1e-3) + np.testing.assert_allclose(nx.to_numpy(prop), [1 - pt, pt], rtol=1e-3, atol=1e-3) @pytest.mark.skipif(nosklearn, reason="No sklearn available") -def test_emd_laplace_class(): +@pytest.skip_backend("jax") +@pytest.skip_backend("tf") +def test_emd_laplace_class(nx): """test_emd_laplace_transport """ ns = 150 @@ -716,6 +719,8 @@ def test_emd_laplace_class(): Xs, ys = make_data_classif('3gauss', ns) Xt, yt = make_data_classif('3gauss2', nt) + Xs, ys, Xt, yt = nx.from_numpy(Xs, ys, Xt, yt) + otda = ot.da.EMDLaplaceTransport(reg_lap=0.01, max_iter=1000, tol=1e-9, verbose=False, log=True) # test its computed @@ -732,15 +737,15 @@ def test_emd_laplace_class(): mu_t = unif(nt) assert_allclose( - np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3) + nx.to_numpy(nx.sum(otda.coupling_, axis=0)), mu_t, rtol=1e-3, atol=1e-3) assert_allclose( - np.sum(otda.coupling_, axis=1), mu_s, 
rtol=1e-3, atol=1e-3) + nx.to_numpy(nx.sum(otda.coupling_, axis=1)), mu_s, rtol=1e-3, atol=1e-3) # test transform transp_Xs = otda.transform(Xs=Xs) [assert_equal(x.shape, y.shape) for x, y in zip(transp_Xs, Xs)] - Xs_new, _ = make_data_classif('3gauss', ns + 1) + Xs_new = nx.from_numpy(make_data_classif('3gauss', ns + 1)[0]) transp_Xs_new = otda.transform(Xs_new) # check that the oos method is working @@ -750,7 +755,7 @@ def test_emd_laplace_class(): transp_Xt = otda.inverse_transform(Xt=Xt) assert_equal(transp_Xt.shape, Xt.shape) - Xt_new, _ = make_data_classif('3gauss2', nt + 1) + Xt_new = nx.from_numpy(make_data_classif('3gauss2', nt + 1)[0]) transp_Xt_new = otda.inverse_transform(Xt=Xt_new) # check that the oos method is working @@ -763,9 +768,9 @@ def test_emd_laplace_class(): # check label propagation transp_yt = otda.transform_labels(ys) assert_equal(transp_yt.shape[0], yt.shape[0]) - assert_equal(transp_yt.shape[1], len(np.unique(ys))) + assert_equal(transp_yt.shape[1], len(np.unique(nx.to_numpy(ys)))) # check inverse label propagation transp_ys = otda.inverse_transform_labels(yt) assert_equal(transp_ys.shape[0], ys.shape[0]) - assert_equal(transp_ys.shape[1], len(np.unique(yt))) + assert_equal(transp_ys.shape[1], len(np.unique(nx.to_numpy(yt)))) diff --git a/test/test_gromov.py b/test/test_gromov.py index 0dcf2da..12fd2b9 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -35,11 +35,7 @@ def test_gromov(nx): C1 /= C1.max() C2 /= C2.max() - C1b = nx.from_numpy(C1) - C2b = nx.from_numpy(C2) - pb = nx.from_numpy(p) - qb = nx.from_numpy(q) - G0b = nx.from_numpy(G0) + C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0) G = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss', G0=G0, verbose=True) Gb = nx.to_numpy(ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', G0=G0b, verbose=True)) @@ -105,11 +101,7 @@ def test_gromov_dtype_device(nx): for tp in nx.__type_list__: print(nx.dtype_device(tp)) - C1b = nx.from_numpy(C1, type_as=tp) - C2b = nx.from_numpy(C2, type_as=tp) - pb = nx.from_numpy(p, type_as=tp) - qb = nx.from_numpy(q, type_as=tp) - G0b = nx.from_numpy(G0, type_as=tp) + C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0, type_as=tp) Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', G0=G0b, verbose=True) gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', G0=G0b, log=False) @@ -136,11 +128,7 @@ def test_gromov_device_tf(): # Check that everything stays on the CPU with tf.device("/CPU:0"): - C1b = nx.from_numpy(C1) - C2b = nx.from_numpy(C2) - pb = nx.from_numpy(p) - qb = nx.from_numpy(q) - G0b = nx.from_numpy(G0) + C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0) Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', G0=G0b, verbose=True) gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', G0=G0b, log=False) nx.assert_same_dtype_device(C1b, Gb) @@ -148,11 +136,7 @@ def test_gromov_device_tf(): if len(tf.config.list_physical_devices('GPU')) > 0: # Check that everything happens on the GPU - C1b = nx.from_numpy(C1) - C2b = nx.from_numpy(C2) - pb = nx.from_numpy(p) - qb = nx.from_numpy(q) - G0b = nx.from_numpy(G0b) + C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0) Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', verbose=True) gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=False) nx.assert_same_dtype_device(C1b, Gb) @@ -222,10 +206,7 @@ def test_entropic_gromov(nx): C1 /= C1.max() C2 /= C2.max() - C1b = nx.from_numpy(C1) - C2b = 
nx.from_numpy(C2) - pb = nx.from_numpy(p) - qb = nx.from_numpy(q) + C1b, C2b, pb, qb = nx.from_numpy(C1, C2, p, q) G = ot.gromov.entropic_gromov_wasserstein( C1, C2, p, q, 'square_loss', epsilon=5e-4, verbose=True) @@ -285,10 +266,7 @@ def test_entropic_gromov_dtype_device(nx): for tp in nx.__type_list__: print(nx.dtype_device(tp)) - C1b = nx.from_numpy(C1, type_as=tp) - C2b = nx.from_numpy(C2, type_as=tp) - pb = nx.from_numpy(p, type_as=tp) - qb = nx.from_numpy(q, type_as=tp) + C1b, C2b, pb, qb = nx.from_numpy(C1, C2, p, q, type_as=tp) Gb = ot.gromov.entropic_gromov_wasserstein( C1b, C2b, pb, qb, 'square_loss', epsilon=5e-4, verbose=True @@ -320,10 +298,7 @@ def test_pointwise_gromov(nx): C1 /= C1.max() C2 /= C2.max() - C1b = nx.from_numpy(C1) - C2b = nx.from_numpy(C2) - pb = nx.from_numpy(p) - qb = nx.from_numpy(q) + C1b, C2b, pb, qb = nx.from_numpy(C1, C2, p, q) def loss(x, y): return np.abs(x - y) @@ -381,10 +356,7 @@ def test_sampled_gromov(nx): C1 /= C1.max() C2 /= C2.max() - C1b = nx.from_numpy(C1) - C2b = nx.from_numpy(C2) - pb = nx.from_numpy(p) - qb = nx.from_numpy(q) + C1b, C2b, pb, qb = nx.from_numpy(C1, C2, p, q) def loss(x, y): return np.abs(x - y) @@ -423,19 +395,15 @@ def test_gromov_barycenter(nx): n_samples = 3 p = ot.unif(n_samples) - C1b = nx.from_numpy(C1) - C2b = nx.from_numpy(C2) - p1b = nx.from_numpy(p1) - p2b = nx.from_numpy(p2) - pb = nx.from_numpy(p) + C1b, C2b, p1b, p2b, pb = nx.from_numpy(C1, C2, p1, p2, p) Cb = ot.gromov.gromov_barycenters( n_samples, [C1, C2], [p1, p2], p, [.5, .5], - 'square_loss', max_iter=100, tol=1e-3, verbose=True, random_state=42 + 'square_loss', max_iter=100, tol=1e-3, verbose=False, random_state=42 ) Cbb = nx.to_numpy(ot.gromov.gromov_barycenters( n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5], - 'square_loss', max_iter=100, tol=1e-3, verbose=True, random_state=42 + 'square_loss', max_iter=100, tol=1e-3, verbose=False, random_state=42 )) np.testing.assert_allclose(Cb, Cbb, atol=1e-06) np.testing.assert_allclose(Cbb.shape, (n_samples, n_samples)) @@ -443,15 +411,15 @@ def test_gromov_barycenter(nx): # test of gromov_barycenters with `log` on Cb_, err_ = ot.gromov.gromov_barycenters( n_samples, [C1, C2], [p1, p2], p, [.5, .5], - 'square_loss', max_iter=100, tol=1e-3, verbose=True, random_state=42, log=True + 'square_loss', max_iter=100, tol=1e-3, verbose=False, random_state=42, log=True ) Cbb_, errb_ = ot.gromov.gromov_barycenters( n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5], - 'square_loss', max_iter=100, tol=1e-3, verbose=True, random_state=42, log=True + 'square_loss', max_iter=100, tol=1e-3, verbose=False, random_state=42, log=True ) Cbb_ = nx.to_numpy(Cbb_) np.testing.assert_allclose(Cb_, Cbb_, atol=1e-06) - np.testing.assert_array_almost_equal(err_['err'], errb_['err']) + np.testing.assert_array_almost_equal(err_['err'], nx.to_numpy(*errb_['err'])) np.testing.assert_allclose(Cbb_.shape, (n_samples, n_samples)) Cb2 = ot.gromov.gromov_barycenters( @@ -468,15 +436,15 @@ def test_gromov_barycenter(nx): # test of gromov_barycenters with `log` on Cb2_, err2_ = ot.gromov.gromov_barycenters( n_samples, [C1, C2], [p1, p2], p, [.5, .5], - 'kl_loss', max_iter=100, tol=1e-3, verbose=True, random_state=42, log=True + 'kl_loss', max_iter=100, tol=1e-3, verbose=False, random_state=42, log=True ) Cb2b_, err2b_ = ot.gromov.gromov_barycenters( n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5], - 'kl_loss', max_iter=100, tol=1e-3, verbose=True, random_state=42, log=True + 'kl_loss', max_iter=100, tol=1e-3, verbose=False, random_state=42, log=True 
) Cb2b_ = nx.to_numpy(Cb2b_) np.testing.assert_allclose(Cb2_, Cb2b_, atol=1e-06) - np.testing.assert_array_almost_equal(err2_['err'], err2_['err']) + np.testing.assert_array_almost_equal(err2_['err'], nx.to_numpy(*err2b_['err'])) np.testing.assert_allclose(Cb2b_.shape, (n_samples, n_samples)) @@ -495,11 +463,7 @@ def test_gromov_entropic_barycenter(nx): n_samples = 2 p = ot.unif(n_samples) - C1b = nx.from_numpy(C1) - C2b = nx.from_numpy(C2) - p1b = nx.from_numpy(p1) - p2b = nx.from_numpy(p2) - pb = nx.from_numpy(p) + C1b, C2b, p1b, p2b, pb = nx.from_numpy(C1, C2, p1, p2, p) Cb = ot.gromov.entropic_gromov_barycenters( n_samples, [C1, C2], [p1, p2], p, [.5, .5], @@ -523,7 +487,7 @@ def test_gromov_entropic_barycenter(nx): ) Cbb_ = nx.to_numpy(Cbb_) np.testing.assert_allclose(Cb_, Cbb_, atol=1e-06) - np.testing.assert_array_almost_equal(err_['err'], errb_['err']) + np.testing.assert_array_almost_equal(err_['err'], nx.to_numpy(*errb_['err'])) np.testing.assert_allclose(Cbb_.shape, (n_samples, n_samples)) Cb2 = ot.gromov.entropic_gromov_barycenters( @@ -548,7 +512,7 @@ def test_gromov_entropic_barycenter(nx): ) Cb2b_ = nx.to_numpy(Cb2b_) np.testing.assert_allclose(Cb2_, Cb2b_, atol=1e-06) - np.testing.assert_array_almost_equal(err2_['err'], err2_['err']) + np.testing.assert_array_almost_equal(err2_['err'], nx.to_numpy(*err2b_['err'])) np.testing.assert_allclose(Cb2b_.shape, (n_samples, n_samples)) @@ -578,12 +542,7 @@ def test_fgw(nx): M = ot.dist(ys, yt) M /= M.max() - Mb = nx.from_numpy(M) - C1b = nx.from_numpy(C1) - C2b = nx.from_numpy(C2) - pb = nx.from_numpy(p) - qb = nx.from_numpy(q) - G0b = nx.from_numpy(G0) + Mb, C1b, C2b, pb, qb, G0b = nx.from_numpy(M, C1, C2, p, q, G0) G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5, G0=G0, log=True) Gb, logb = ot.gromov.fused_gromov_wasserstein(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, G0=G0b, log=True) @@ -681,13 +640,7 @@ def test_fgw_barycenter(nx): n_samples = 3 p = ot.unif(n_samples) - ysb = nx.from_numpy(ys) - ytb = nx.from_numpy(yt) - C1b = nx.from_numpy(C1) - C2b = nx.from_numpy(C2) - p1b = nx.from_numpy(p1) - p2b = nx.from_numpy(p2) - pb = nx.from_numpy(p) + ysb, ytb, C1b, C2b, p1b, p2b, pb = nx.from_numpy(ys, yt, C1, C2, p1, p2, p) Xb, Cb = ot.gromov.fgw_barycenters( n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], [.5, .5], 0.5, fixed_structure=False, @@ -731,10 +684,8 @@ def test_gromov_wasserstein_linear_unmixing(nx): Cdict = np.stack([C1, C2]) p = ot.unif(n) - C1b = nx.from_numpy(C1) - C2b = nx.from_numpy(C2) - Cdictb = nx.from_numpy(Cdict) - pb = nx.from_numpy(p) + C1b, C2b, Cdictb, pb = nx.from_numpy(C1, C2, Cdict, p) + tol = 10**(-5) # Tests without regularization reg = 0. 
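The hunks above all follow the same backend-testing pattern: inputs are built in NumPy, moved to the tested backend with a single multi-argument `nx.from_numpy` call, and results are converted back with `nx.to_numpy` before any NumPy assertion. A minimal sketch of that pattern, as a hypothetical test assuming the `nx` backend fixture used throughout this suite:

    import numpy as np
    import ot

    def test_backend_pattern_sketch(nx):
        rng = np.random.RandomState(0)
        x = rng.randn(10, 2)
        a = ot.unif(10)
        M = ot.dist(x, x)
        # one call converts several NumPy arrays to the tested backend
        ab, Mb = nx.from_numpy(a, M)
        # reference run on NumPy arrays, second run on backend arrays
        G = ot.emd(a, a, M)
        Gb = ot.emd(ab, ab, Mb)
        # convert back before comparing with NumPy testing utilities
        np.testing.assert_allclose(G, nx.to_numpy(Gb), atol=1e-7)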
@@ -764,8 +715,8 @@ def test_gromov_wasserstein_linear_unmixing(nx): np.testing.assert_allclose(unmixing2, [0., 1.], atol=1e-01) np.testing.assert_allclose(C1_emb, nx.to_numpy(C1b_emb), atol=1e-06) np.testing.assert_allclose(C2_emb, nx.to_numpy(C2b_emb), atol=1e-06) - np.testing.assert_allclose(reconstruction1, reconstruction1b, atol=1e-06) - np.testing.assert_allclose(reconstruction2, reconstruction2b, atol=1e-06) + np.testing.assert_allclose(reconstruction1, nx.to_numpy(reconstruction1b), atol=1e-06) + np.testing.assert_allclose(reconstruction2, nx.to_numpy(reconstruction2b), atol=1e-06) np.testing.assert_allclose(C1b_emb.shape, (n, n)) np.testing.assert_allclose(C2b_emb.shape, (n, n)) @@ -798,8 +749,8 @@ def test_gromov_wasserstein_linear_unmixing(nx): np.testing.assert_allclose(unmixing2, [0., 1.], atol=1e-01) np.testing.assert_allclose(C1_emb, nx.to_numpy(C1b_emb), atol=1e-06) np.testing.assert_allclose(C2_emb, nx.to_numpy(C2b_emb), atol=1e-06) - np.testing.assert_allclose(reconstruction1, reconstruction1b, atol=1e-06) - np.testing.assert_allclose(reconstruction2, reconstruction2b, atol=1e-06) + np.testing.assert_allclose(reconstruction1, nx.to_numpy(reconstruction1b), atol=1e-06) + np.testing.assert_allclose(reconstruction2, nx.to_numpy(reconstruction2b), atol=1e-06) np.testing.assert_allclose(C1b_emb.shape, (n, n)) np.testing.assert_allclose(C2b_emb.shape, (n, n)) @@ -824,13 +775,14 @@ def test_gromov_wasserstein_dictionary_learning(nx): dataset_means = [C.mean() for C in Cs] np.random.seed(0) Cdict_init = np.random.normal(loc=np.mean(dataset_means), scale=np.std(dataset_means), size=(n_atoms, shape, shape)) + if projection == 'nonnegative_symmetric': Cdict_init = 0.5 * (Cdict_init + Cdict_init.transpose((0, 2, 1))) Cdict_init[Cdict_init < 0.] = 0. 
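In the dictionary initialization just above, the 'nonnegative_symmetric' projection symmetrizes each random atom and then clips negative entries, presumably so that every atom stays a valid (adjacency- or distance-like) structure matrix. A small worked example on a hypothetical 2x2 atom:

    import numpy as np

    Cdict_init = np.array([[[0.2, -0.8],
                            [0.2, 0.0]]])  # one hypothetical 2x2 atom
    sym = 0.5 * (Cdict_init + Cdict_init.transpose((0, 2, 1)))
    # sym is [[[0.2, -0.3], [-0.3, 0.0]]]
    sym[sym < 0.] = 0.
    # after clipping: [[[0.2, 0.0], [0.0, 0.0]]], i.e. symmetric and non-negative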
- Csb = [nx.from_numpy(C) for C in Cs] - psb = [nx.from_numpy(p) for p in ps] - qb = nx.from_numpy(q) - Cdict_initb = nx.from_numpy(Cdict_init) + + Csb = nx.from_numpy(*Cs) + psb = nx.from_numpy(*ps) + qb, Cdict_initb = nx.from_numpy(q, Cdict_init) # Test: compare reconstruction error using initial dictionary and dictionary learned using this initialization # > Compute initial reconstruction of samples on this random dictionary without backend @@ -882,6 +834,7 @@ def test_gromov_wasserstein_dictionary_learning(nx): ) total_reconstruction_b += reconstruction + total_reconstruction_b = nx.to_numpy(total_reconstruction_b) np.testing.assert_array_less(total_reconstruction_b, initial_total_reconstruction) np.testing.assert_allclose(total_reconstruction_b, total_reconstruction, atol=1e-05) np.testing.assert_allclose(total_reconstruction_b, total_reconstruction, atol=1e-05) @@ -924,6 +877,7 @@ def test_gromov_wasserstein_dictionary_learning(nx): ) total_reconstruction_b_bis += reconstruction + total_reconstruction_b_bis = nx.to_numpy(total_reconstruction_b_bis) np.testing.assert_allclose(total_reconstruction_b_bis, total_reconstruction_b, atol=1e-05) np.testing.assert_allclose(Cdict_bis, nx.to_numpy(Cdictb_bis), atol=1e-03) @@ -969,6 +923,7 @@ def test_gromov_wasserstein_dictionary_learning(nx): ) total_reconstruction_b_bis2 += reconstruction + total_reconstruction_b_bis2 = nx.to_numpy(total_reconstruction_b_bis2) np.testing.assert_allclose(total_reconstruction_b_bis2, total_reconstruction_bis2, atol=1e-05) @@ -985,12 +940,8 @@ def test_fused_gromov_wasserstein_linear_unmixing(nx): Ydict = np.stack([F, F]) p = ot.unif(n) - C1b = nx.from_numpy(C1) - C2b = nx.from_numpy(C2) - Fb = nx.from_numpy(F) - Cdictb = nx.from_numpy(Cdict) - Ydictb = nx.from_numpy(Ydict) - pb = nx.from_numpy(p) + C1b, C2b, Fb, Cdictb, Ydictb, pb = nx.from_numpy(C1, C2, F, Cdict, Ydict, p) + # Tests without regularization reg = 0. 
@@ -1022,8 +973,8 @@ def test_fused_gromov_wasserstein_linear_unmixing(nx): np.testing.assert_allclose(C2_emb, nx.to_numpy(C2b_emb), atol=1e-03) np.testing.assert_allclose(Y1_emb, nx.to_numpy(Y1b_emb), atol=1e-03) np.testing.assert_allclose(Y2_emb, nx.to_numpy(Y2b_emb), atol=1e-03) - np.testing.assert_allclose(reconstruction1, reconstruction1b, atol=1e-06) - np.testing.assert_allclose(reconstruction2, reconstruction2b, atol=1e-06) + np.testing.assert_allclose(reconstruction1, nx.to_numpy(reconstruction1b), atol=1e-06) + np.testing.assert_allclose(reconstruction2, nx.to_numpy(reconstruction2b), atol=1e-06) np.testing.assert_allclose(C1b_emb.shape, (n, n)) np.testing.assert_allclose(C2b_emb.shape, (n, n)) @@ -1058,8 +1009,8 @@ def test_fused_gromov_wasserstein_linear_unmixing(nx): np.testing.assert_allclose(C2_emb, nx.to_numpy(C2b_emb), atol=1e-03) np.testing.assert_allclose(Y1_emb, nx.to_numpy(Y1b_emb), atol=1e-03) np.testing.assert_allclose(Y2_emb, nx.to_numpy(Y2b_emb), atol=1e-03) - np.testing.assert_allclose(reconstruction1, reconstruction1b, atol=1e-06) - np.testing.assert_allclose(reconstruction2, reconstruction2b, atol=1e-06) + np.testing.assert_allclose(reconstruction1, nx.to_numpy(reconstruction1b), atol=1e-06) + np.testing.assert_allclose(reconstruction2, nx.to_numpy(reconstruction2b), atol=1e-06) np.testing.assert_allclose(C1b_emb.shape, (n, n)) np.testing.assert_allclose(C2b_emb.shape, (n, n)) @@ -1093,12 +1044,10 @@ def test_fused_gromov_wasserstein_dictionary_learning(nx): dataset_feature_means = np.stack([Y.mean(axis=0) for Y in Ys]) Ydict_init = np.random.normal(loc=dataset_feature_means.mean(axis=0), scale=dataset_feature_means.std(axis=0), size=(n_atoms, shape, 2)) - Csb = [nx.from_numpy(C) for C in Cs] - Ysb = [nx.from_numpy(Y) for Y in Ys] - psb = [nx.from_numpy(p) for p in ps] - qb = nx.from_numpy(q) - Cdict_initb = nx.from_numpy(Cdict_init) - Ydict_initb = nx.from_numpy(Ydict_init) + Csb = nx.from_numpy(*Cs) + Ysb = nx.from_numpy(*Ys) + psb = nx.from_numpy(*ps) + qb, Cdict_initb, Ydict_initb = nx.from_numpy(q, Cdict_init, Ydict_init) # Test: Compute initial reconstruction of samples on this random dictionary alpha = 0.5 @@ -1151,6 +1100,7 @@ def test_fused_gromov_wasserstein_dictionary_learning(nx): ) total_reconstruction_b += reconstruction + total_reconstruction_b = nx.to_numpy(total_reconstruction_b) np.testing.assert_array_less(total_reconstruction_b, initial_total_reconstruction) np.testing.assert_allclose(total_reconstruction_b, total_reconstruction, atol=1e-05) np.testing.assert_allclose(Cdict, nx.to_numpy(Cdictb), atol=1e-03) @@ -1192,6 +1142,8 @@ def test_fused_gromov_wasserstein_dictionary_learning(nx): tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 ) total_reconstruction_b_bis += reconstruction + + total_reconstruction_b_bis = nx.to_numpy(total_reconstruction_b_bis) np.testing.assert_allclose(total_reconstruction_b_bis, total_reconstruction_b, atol=1e-05) # Test: without using adam optimizer, with log and verbose set to True @@ -1237,4 +1189,5 @@ def test_fused_gromov_wasserstein_dictionary_learning(nx): total_reconstruction_b_bis2 += reconstruction # > Compare results with/without backend + total_reconstruction_b_bis2 = nx.to_numpy(total_reconstruction_b_bis2) np.testing.assert_allclose(total_reconstruction_bis2, total_reconstruction_b_bis2, atol=1e-05) diff --git a/test/test_optim.py b/test/test_optim.py index 41f9cbe..67e9d13 100644 --- a/test/test_optim.py +++ b/test/test_optim.py @@ -32,9 +32,7 @@ def test_conditional_gradient(nx): 
def fb(G): return 0.5 * nx.sum(G ** 2) - ab = nx.from_numpy(a) - bb = nx.from_numpy(b) - Mb = nx.from_numpy(M, type_as=ab) + ab, bb, Mb = nx.from_numpy(a, b, M) reg = 1e-1 @@ -74,9 +72,7 @@ def test_conditional_gradient_itermax(nx): def fb(G): return 0.5 * nx.sum(G ** 2) - ab = nx.from_numpy(a) - bb = nx.from_numpy(b) - Mb = nx.from_numpy(M, type_as=ab) + ab, bb, Mb = nx.from_numpy(a, b, M) reg = 1e-1 @@ -118,9 +114,7 @@ def test_generalized_conditional_gradient(nx): reg1 = 1e-3 reg2 = 1e-1 - ab = nx.from_numpy(a) - bb = nx.from_numpy(b) - Mb = nx.from_numpy(M, type_as=ab) + ab, bb, Mb = nx.from_numpy(a, b, M) G, log = ot.optim.gcg(a, b, M, reg1, reg2, f, df, verbose=True, log=True) Gb, log = ot.optim.gcg(ab, bb, Mb, reg1, reg2, fb, df, verbose=True, log=True) @@ -142,9 +136,12 @@ def test_line_search_armijo(nx): pk = np.array([[-0.25, 0.25], [0.25, -0.25]]) gfk = np.array([[23.04273441, 23.0449082], [23.04273441, 23.0449082]]) old_fval = -123 + + xkb, pkb, gfkb = nx.from_numpy(xk, pk, gfk) + # Should not throw an exception and return 0. for alpha alpha, a, b = ot.optim.line_search_armijo( - lambda x: 1, nx.from_numpy(xk), nx.from_numpy(pk), nx.from_numpy(gfk), old_fval + lambda x: 1, xkb, pkb, gfkb, old_fval ) alpha_np, anp, bnp = ot.optim.line_search_armijo( lambda x: 1, xk, pk, gfk, old_fval diff --git a/test/test_ot.py b/test/test_ot.py index 3e2d845..bb258e2 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -47,8 +47,7 @@ def test_emd_backends(nx): G = ot.emd(a, a, M) - ab = nx.from_numpy(a) - Mb = nx.from_numpy(M) + ab, Mb = nx.from_numpy(a, M) Gb = ot.emd(ab, ab, Mb) @@ -68,8 +67,7 @@ def test_emd2_backends(nx): val = ot.emd2(a, a, M) - ab = nx.from_numpy(a) - Mb = nx.from_numpy(M) + ab, Mb = nx.from_numpy(a, M) valb = ot.emd2(ab, ab, Mb) @@ -90,8 +88,7 @@ def test_emd_emd2_types_devices(nx): for tp in nx.__type_list__: print(nx.dtype_device(tp)) - ab = nx.from_numpy(a, type_as=tp) - Mb = nx.from_numpy(M, type_as=tp) + ab, Mb = nx.from_numpy(a, M, type_as=tp) Gb = ot.emd(ab, ab, Mb) @@ -117,8 +114,7 @@ def test_emd_emd2_devices_tf(): # Check that everything stays on the CPU with tf.device("/CPU:0"): - ab = nx.from_numpy(a) - Mb = nx.from_numpy(M) + ab, Mb = nx.from_numpy(a, M) Gb = ot.emd(ab, ab, Mb) w = ot.emd2(ab, ab, Mb) nx.assert_same_dtype_device(Mb, Gb) @@ -126,8 +122,7 @@ def test_emd_emd2_devices_tf(): if len(tf.config.list_physical_devices('GPU')) > 0: # Check that everything happens on the GPU - ab = nx.from_numpy(a) - Mb = nx.from_numpy(M) + ab, Mb = nx.from_numpy(a, M) Gb = ot.emd(ab, ab, Mb) w = ot.emd2(ab, ab, Mb) nx.assert_same_dtype_device(Mb, Gb) @@ -310,8 +305,8 @@ def test_free_support_barycenter_backends(nx): X = ot.lp.free_support_barycenter(measures_locations, measures_weights, X_init) - measures_locations2 = [nx.from_numpy(x) for x in measures_locations] - measures_weights2 = [nx.from_numpy(x) for x in measures_weights] + measures_locations2 = nx.from_numpy(*measures_locations) + measures_weights2 = nx.from_numpy(*measures_weights) X_init2 = nx.from_numpy(X_init) X2 = ot.lp.free_support_barycenter(measures_locations2, measures_weights2, X_init2) diff --git a/test/test_sliced.py b/test/test_sliced.py index 91e0961..08ab4fb 100644 --- a/test/test_sliced.py +++ b/test/test_sliced.py @@ -123,9 +123,7 @@ def test_sliced_backend(nx): n_projections = 20 - xb = nx.from_numpy(x) - yb = nx.from_numpy(y) - Pb = nx.from_numpy(P) + xb, yb, Pb = nx.from_numpy(x, y, P) val0 = ot.sliced_wasserstein_distance(x, y, projections=P) @@ -153,9 +151,7 @@ def 
test_sliced_backend_type_devices(nx): for tp in nx.__type_list__: print(nx.dtype_device(tp)) - xb = nx.from_numpy(x, type_as=tp) - yb = nx.from_numpy(y, type_as=tp) - Pb = nx.from_numpy(P, type_as=tp) + xb, yb, Pb = nx.from_numpy(x, y, P, type_as=tp) valb = ot.sliced_wasserstein_distance(xb, yb, projections=Pb) @@ -174,17 +170,13 @@ def test_sliced_backend_device_tf(): # Check that everything stays on the CPU with tf.device("/CPU:0"): - xb = nx.from_numpy(x) - yb = nx.from_numpy(y) - Pb = nx.from_numpy(P) + xb, yb, Pb = nx.from_numpy(x, y, P) valb = ot.sliced_wasserstein_distance(xb, yb, projections=Pb) nx.assert_same_dtype_device(xb, valb) if len(tf.config.list_physical_devices('GPU')) > 0: # Check that everything happens on the GPU - xb = nx.from_numpy(x) - yb = nx.from_numpy(y) - Pb = nx.from_numpy(P) + xb, yb, Pb = nx.from_numpy(x, y, P) valb = ot.sliced_wasserstein_distance(xb, yb, projections=Pb) nx.assert_same_dtype_device(xb, valb) assert nx.dtype_device(valb)[1].startswith("GPU") @@ -203,9 +195,7 @@ def test_max_sliced_backend(nx): n_projections = 20 - xb = nx.from_numpy(x) - yb = nx.from_numpy(y) - Pb = nx.from_numpy(P) + xb, yb, Pb = nx.from_numpy(x, y, P) val0 = ot.max_sliced_wasserstein_distance(x, y, projections=P) @@ -233,9 +223,7 @@ def test_max_sliced_backend_type_devices(nx): for tp in nx.__type_list__: print(nx.dtype_device(tp)) - xb = nx.from_numpy(x, type_as=tp) - yb = nx.from_numpy(y, type_as=tp) - Pb = nx.from_numpy(P, type_as=tp) + xb, yb, Pb = nx.from_numpy(x, y, P, type_as=tp) valb = ot.max_sliced_wasserstein_distance(xb, yb, projections=Pb) @@ -254,17 +242,13 @@ def test_max_sliced_backend_device_tf(): # Check that everything stays on the CPU with tf.device("/CPU:0"): - xb = nx.from_numpy(x) - yb = nx.from_numpy(y) - Pb = nx.from_numpy(P) + xb, yb, Pb = nx.from_numpy(x, y, P) valb = ot.max_sliced_wasserstein_distance(xb, yb, projections=Pb) nx.assert_same_dtype_device(xb, valb) if len(tf.config.list_physical_devices('GPU')) > 0: # Check that everything happens on the GPU - xb = nx.from_numpy(x) - yb = nx.from_numpy(y) - Pb = nx.from_numpy(P) + xb, yb, Pb = nx.from_numpy(x, y, P) valb = ot.max_sliced_wasserstein_distance(xb, yb, projections=Pb) nx.assert_same_dtype_device(xb, valb) assert nx.dtype_device(valb)[1].startswith("GPU") diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py index e8349d1..db59504 100644 --- a/test/test_unbalanced.py +++ b/test/test_unbalanced.py @@ -9,11 +9,9 @@ import ot import pytest from ot.unbalanced import barycenter_unbalanced -from scipy.special import logsumexp - @pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"]) -def test_unbalanced_convergence(method): +def test_unbalanced_convergence(nx, method): # test generalized sinkhorn for unbalanced OT n = 100 rng = np.random.RandomState(42) @@ -28,36 +26,51 @@ def test_unbalanced_convergence(method): epsilon = 1. reg_m = 1. 
+ a, b, M = nx.from_numpy(a, b, M) + G, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, reg_m=reg_m, method=method, log=True, verbose=True) - loss = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m, - method=method, - verbose=True) + loss = nx.to_numpy(ot.unbalanced.sinkhorn_unbalanced2( + a, b, M, epsilon, reg_m, method=method, verbose=True + )) # check fixed point equations # in log-domain fi = reg_m / (reg_m + epsilon) - logb = np.log(b + 1e-16) - loga = np.log(a + 1e-16) - logKtu = logsumexp(log["logu"][None, :] - M.T / epsilon, axis=1) - logKv = logsumexp(log["logv"][None, :] - M / epsilon, axis=1) + logb = nx.log(b + 1e-16) + loga = nx.log(a + 1e-16) + logKtu = nx.logsumexp(log["logu"][None, :] - M.T / epsilon, axis=1) + logKv = nx.logsumexp(log["logv"][None, :] - M / epsilon, axis=1) v_final = fi * (logb - logKtu) u_final = fi * (loga - logKv) np.testing.assert_allclose( - u_final, log["logu"], atol=1e-05) + nx.to_numpy(u_final), nx.to_numpy(log["logu"]), atol=1e-05) np.testing.assert_allclose( - v_final, log["logv"], atol=1e-05) + nx.to_numpy(v_final), nx.to_numpy(log["logv"]), atol=1e-05) # check if sinkhorn_unbalanced2 returns the correct loss - np.testing.assert_allclose((G * M).sum(), loss, atol=1e-5) + np.testing.assert_allclose(nx.to_numpy(nx.sum(G * M)), loss, atol=1e-5) + + # check in case no histogram is provided + M_np = nx.to_numpy(M) + a_np, b_np = np.array([]), np.array([]) + a, b = nx.from_numpy(a_np, b_np) + + G = ot.unbalanced.sinkhorn_unbalanced( + a, b, M, reg=epsilon, reg_m=reg_m, method=method, verbose=True + ) + G_np = ot.unbalanced.sinkhorn_unbalanced( + a_np, b_np, M_np, reg=epsilon, reg_m=reg_m, method=method, verbose=True + ) + np.testing.assert_allclose(G_np, nx.to_numpy(G)) @pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"]) -def test_unbalanced_multiple_inputs(method): +def test_unbalanced_multiple_inputs(nx, method): # test generalized sinkhorn for unbalanced OT n = 100 rng = np.random.RandomState(42) @@ -72,6 +85,8 @@ def test_unbalanced_multiple_inputs(method): epsilon = 1. reg_m = 1. + a, b, M = nx.from_numpy(a, b, M) + loss, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, reg_m=reg_m, method=method, @@ -80,23 +95,24 @@ def test_unbalanced_multiple_inputs(method): # check fixed point equations # in log-domain fi = reg_m / (reg_m + epsilon) - logb = np.log(b + 1e-16) - loga = np.log(a + 1e-16)[:, None] - logKtu = logsumexp(log["logu"][:, None, :] - M[:, :, None] / epsilon, - axis=0) - logKv = logsumexp(log["logv"][None, :] - M[:, :, None] / epsilon, axis=1) + logb = nx.log(b + 1e-16) + loga = nx.log(a + 1e-16)[:, None] + logKtu = nx.logsumexp( + log["logu"][:, None, :] - M[:, :, None] / epsilon, axis=0 + ) + logKv = nx.logsumexp(log["logv"][None, :] - M[:, :, None] / epsilon, axis=1) v_final = fi * (logb - logKtu) u_final = fi * (loga - logKv) np.testing.assert_allclose( - u_final, log["logu"], atol=1e-05) + nx.to_numpy(u_final), nx.to_numpy(log["logu"]), atol=1e-05) np.testing.assert_allclose( - v_final, log["logv"], atol=1e-05) + nx.to_numpy(v_final), nx.to_numpy(log["logv"]), atol=1e-05) assert len(loss) == b.shape[1] -def test_stabilized_vs_sinkhorn(): +def test_stabilized_vs_sinkhorn(nx): # test if stable version matches sinkhorn n = 100 @@ -112,19 +128,27 @@ def test_stabilized_vs_sinkhorn(): M /= np.median(M) epsilon = 0.1 reg_m = 1. 
- G, log = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, reg=epsilon, - method="sinkhorn_stabilized", - reg_m=reg_m, - log=True, - verbose=True) - G2, log2 = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m, - method="sinkhorn", log=True) + + ab, bb, Mb = nx.from_numpy(a, b, M) + + G, _ = ot.unbalanced.sinkhorn_unbalanced2( + ab, bb, Mb, epsilon, reg_m, method="sinkhorn_stabilized", log=True + ) + G2, _ = ot.unbalanced.sinkhorn_unbalanced2( + ab, bb, Mb, epsilon, reg_m, method="sinkhorn", log=True + ) + G2_np, _ = ot.unbalanced.sinkhorn_unbalanced2( + a, b, M, epsilon, reg_m, method="sinkhorn", log=True + ) + G = nx.to_numpy(G) + G2 = nx.to_numpy(G2) np.testing.assert_allclose(G, G2, atol=1e-5) + np.testing.assert_allclose(G2, G2_np, atol=1e-5) @pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"]) -def test_unbalanced_barycenter(method): +def test_unbalanced_barycenter(nx, method): # test generalized sinkhorn for unbalanced OT barycenter n = 100 rng = np.random.RandomState(42) @@ -138,25 +162,29 @@ def test_unbalanced_barycenter(method): epsilon = 1. reg_m = 1. - q, log = barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m, - method=method, log=True, verbose=True) + A, M = nx.from_numpy(A, M) + + q, log = barycenter_unbalanced( + A, M, reg=epsilon, reg_m=reg_m, method=method, log=True, verbose=True + ) # check fixed point equations fi = reg_m / (reg_m + epsilon) - logA = np.log(A + 1e-16) - logq = np.log(q + 1e-16)[:, None] - logKtu = logsumexp(log["logu"][:, None, :] - M[:, :, None] / epsilon, - axis=0) - logKv = logsumexp(log["logv"][None, :] - M[:, :, None] / epsilon, axis=1) + logA = nx.log(A + 1e-16) + logq = nx.log(q + 1e-16)[:, None] + logKtu = nx.logsumexp( + log["logu"][:, None, :] - M[:, :, None] / epsilon, axis=0 + ) + logKv = nx.logsumexp(log["logv"][None, :] - M[:, :, None] / epsilon, axis=1) v_final = fi * (logq - logKtu) u_final = fi * (logA - logKv) np.testing.assert_allclose( - u_final, log["logu"], atol=1e-05) + nx.to_numpy(u_final), nx.to_numpy(log["logu"]), atol=1e-05) np.testing.assert_allclose( - v_final, log["logv"], atol=1e-05) + nx.to_numpy(v_final), nx.to_numpy(log["logv"]), atol=1e-05) -def test_barycenter_stabilized_vs_sinkhorn(): +def test_barycenter_stabilized_vs_sinkhorn(nx): # test generalized sinkhorn for unbalanced OT barycenter n = 100 rng = np.random.RandomState(42) @@ -170,21 +198,24 @@ def test_barycenter_stabilized_vs_sinkhorn(): epsilon = 0.5 reg_m = 10 - qstable, log = barycenter_unbalanced(A, M, reg=epsilon, - reg_m=reg_m, log=True, - tau=100, - method="sinkhorn_stabilized", - verbose=True - ) - q, log = barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m, - method="sinkhorn", - log=True) + Ab, Mb = nx.from_numpy(A, M) - np.testing.assert_allclose( - q, qstable, atol=1e-05) + qstable, _ = barycenter_unbalanced( + Ab, Mb, reg=epsilon, reg_m=reg_m, log=True, tau=100, + method="sinkhorn_stabilized", verbose=True + ) + q, _ = barycenter_unbalanced( + Ab, Mb, reg=epsilon, reg_m=reg_m, method="sinkhorn", log=True + ) + q_np, _ = barycenter_unbalanced( + A, M, reg=epsilon, reg_m=reg_m, method="sinkhorn", log=True + ) + q, qstable = nx.to_numpy(q, qstable) + np.testing.assert_allclose(q, qstable, atol=1e-05) + np.testing.assert_allclose(q, q_np, atol=1e-05) -def test_wrong_method(): +def test_wrong_method(nx): n = 10 rng = np.random.RandomState(42) @@ -199,19 +230,20 @@ def test_wrong_method(): epsilon = 1. reg_m = 1. 
+ a, b, M = nx.from_numpy(a, b, M) + with pytest.raises(ValueError): - ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, - reg_m=reg_m, - method='badmethod', - log=True, - verbose=True) + ot.unbalanced.sinkhorn_unbalanced( + a, b, M, reg=epsilon, reg_m=reg_m, method='badmethod', + log=True, verbose=True + ) with pytest.raises(ValueError): - ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m, - method='badmethod', - verbose=True) + ot.unbalanced.sinkhorn_unbalanced2( + a, b, M, epsilon, reg_m, method='badmethod', verbose=True + ) -def test_implemented_methods(): +def test_implemented_methods(nx): IMPLEMENTED_METHODS = ['sinkhorn', 'sinkhorn_stabilized'] TO_BE_IMPLEMENTED_METHODS = ['sinkhorn_reg_scaling'] NOT_VALID_TOKENS = ['foo'] @@ -228,6 +260,9 @@ def test_implemented_methods(): M = ot.dist(x, x) epsilon = 1. reg_m = 1. + + a, b, M, A = nx.from_numpy(a, b, M, A) + for method in IMPLEMENTED_METHODS: ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, reg_m, method=method) diff --git a/test/test_weak.py b/test/test_weak.py index c4c3278..945efb1 100644 --- a/test/test_weak.py +++ b/test/test_weak.py @@ -45,9 +45,7 @@ def test_weak_ot_bakends(nx): G = ot.weak_optimal_transport(xs, xt, u, u) - xs2 = nx.from_numpy(xs) - xt2 = nx.from_numpy(xt) - u2 = nx.from_numpy(u) + xs2, xt2, u2 = nx.from_numpy(xs, xt, u) G2 = ot.weak_optimal_transport(xs2, xt2, u2, u2) -- cgit v1.2.3 From 0b223ff883fd73601984a92c31cb70d4aded16e8 Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Thu, 7 Apr 2022 14:18:54 +0200 Subject: [MRG] Remove deprecated ot.gpu submodule (#361) * remove all cpu submodule and tests * speedup tests gromov --- README.md | 2 +- RELEASES.md | 1 + docs/source/quickstart.rst | 58 +++++++++++--- ot/gpu/__init__.py | 50 ------------ ot/gpu/bregman.py | 196 --------------------------------------------- ot/gpu/da.py | 144 --------------------------------- ot/gpu/utils.py | 101 ----------------------- test/test_gpu.py | 106 ------------------------ test/test_gromov.py | 129 ++++++++++++++--------------- 9 files changed, 113 insertions(+), 674 deletions(-) delete mode 100644 ot/gpu/__init__.py delete mode 100644 ot/gpu/bregman.py delete mode 100644 ot/gpu/da.py delete mode 100644 ot/gpu/utils.py delete mode 100644 test/test_gpu.py (limited to 'test/test_gromov.py') diff --git a/README.md b/README.md index 0c3bd19..2ace69c 100644 --- a/README.md +++ b/README.md @@ -185,7 +185,7 @@ The contributors to this library are * [Alexandre Gramfort](http://alexandre.gramfort.net/) (CI, documentation) * [Laetitia Chapel](http://people.irisa.fr/Laetitia.Chapel/) (Partial OT) * [Michael Perrot](http://perso.univ-st-etienne.fr/pem82055/) (Mapping estimation) -* [Léo Gautheron](https://github.com/aje) (GPU implementation) +* [Léo Gautheron](https://github.com/aje) (Initial GPU implementation) * [Nathalie Gayraud](https://www.linkedin.com/in/nathalie-t-h-gayraud/?ppe=1) (DA classes) * [Stanislas Chambon](https://slasnista.github.io/) (DA classes) * [Antoine Rolet](https://arolet.github.io/) (EMD solver debug) diff --git a/RELEASES.md b/RELEASES.md index 7d458f3..b54a84a 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -5,6 +5,7 @@ #### New features +- remode deprecated `ot.gpu` submodule (PR #361) - Update examples in the gallery (PR #359). - Add stochastic loss and OT plan computation for regularized OT and backend examples(PR #360). 
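The commit above drops the deprecated `ot.gpu` submodule in favour of the generic backend mechanism announced in the release notes. As a rough migration sketch (a minimal example assuming the PyTorch backend and a CUDA device; the tensor names and sizes are illustrative and not taken from POT), the entropic solver stays on the GPU when it is simply fed CUDA tensors:

import torch
import ot

device = "cuda" if torch.cuda.is_available() else "cpu"

xs = torch.randn(500, 2, device=device)          # source samples (illustrative size)
xt = torch.randn(500, 2, device=device)          # target samples (illustrative size)
a = torch.full((500,), 1 / 500, device=device)   # uniform source weights
b = torch.full((500,), 1 / 500, device=device)   # uniform target weights

M = ot.dist(xs, xt)             # cost matrix, computed on the GPU
G = ot.sinkhorn(a, b, M, 1e-1)  # entropic OT plan, returned on the same device

Because the solvers dispatch on the type of their inputs, the plan G remains a torch tensor on the GPU until it is explicitly converted back to NumPy or moved to the CPU.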
diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst index 09a362b..b4cc8ab 100644 --- a/docs/source/quickstart.rst +++ b/docs/source/quickstart.rst @@ -1028,15 +1028,6 @@ FAQ speedup can be obtained by using a GPU implementation since all operations are matrix/vector products. -4. **Using GPU fails with error: module 'ot' has no attribute 'gpu'** - - In order to limit import time and hard dependencies in POT. we do not import - some sub-modules automatically with :code:`import ot`. In order to use the - acceleration in :any:`ot.gpu` you need first to import is with - :code:`import ot.gpu`. - - See `Issue #85 `__ and :any:`ot.gpu` - for more details. References @@ -1172,3 +1163,52 @@ References .. [30] Flamary, Rémi, et al. "Optimal transport with Laplacian regularization: Applications to domain adaptation and shape matching." NIPS Workshop on Optimal Transport and Machine Learning OTML. 2014. + +.. [31] Bonneel, Nicolas, et al. `Sliced and radon wasserstein barycenters of + measures + `_\ + , Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45 + +.. [32] Huang, M., Ma S., Lai, L. (2021). `A Riemannian Block Coordinate Descent Method for Computing the Projection Robust Wasserstein Distance `_\ , Proceedings of the 38th International Conference on Machine Learning (ICML). + +.. [33] Kerdoncuff T., Emonet R., Marc S. `Sampled Gromov Wasserstein + `_\ , Machine + Learning Journal (MJL), 2021 + +.. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, A., & + Peyré, G. (2019, April). `Interpolating between optimal transport and MMD + using Sinkhorn divergences + `_. In The 22nd + International Conference on Artificial Intelligence and Statistics (pp. + 2681-2690). PMLR. + +.. [35] Deshpande, I., Hu, Y. T., Sun, R., Pyrros, A., Siddiqui, N., Koyejo, S., + & Schwing, A. G. (2019). `Max-sliced wasserstein distance and its use + for gans + `_. + In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 10648-10656). + +.. [36] Liutkus, A., Simsekli, U., Majewski, S., Durmus, A., & Stöter, F. R. + (2019, May). `Sliced-Wasserstein flows: Nonparametric generative modeling via + optimal transport and diffusions + `_. In International + Conference on Machine Learning (pp. 4104-4113). PMLR. + +.. [37] Janati, H., Cuturi, M., Gramfort, A. `Debiased sinkhorn barycenters + `_ Proceedings of + the 37th International Conference on Machine Learning, PMLR 119:4692-4701, 2020 + +.. [38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, `Online + Graph Dictionary Learning `_\ , + International Conference on Machine Learning (ICML), 2021. + +.. [39] Gozlan, N., Roberto, C., Samson, P. M., & Tetali, P. (2017). + `Kantorovich duality for general transport costs and applications + `_. + Journal of Functional Analysis, 273(11), 3327-3405. + +.. [40] Forrow, A., Hütter, J. C., Nitzan, M., Rigollet, P., Schiebinger, G., & + Weed, J. (2019, April). `Statistical optimal transport via factored + couplings `_. In + The 22nd International Conference on Artificial Intelligence and Statistics + (pp. 2454-2465). PMLR. diff --git a/ot/gpu/__init__.py b/ot/gpu/__init__.py deleted file mode 100644 index 12db605..0000000 --- a/ot/gpu/__init__.py +++ /dev/null @@ -1,50 +0,0 @@ -# -*- coding: utf-8 -*- -""" -GPU implementation for several OT solvers and utility -functions. - -The GPU backend in handled by `cupy -`_. - -.. warning:: - This module is now deprecated and will be removed in future releases. 
POT - now privides a backend mechanism that allows for solving prolem on GPU wth - the pytorch backend. - - -.. warning:: - Note that by default the module is not imported in :mod:`ot`. In order to - use it you need to explicitely import :mod:`ot.gpu` . - -By default, the functions in this module accept and return numpy arrays -in order to proide drop-in replacement for the other POT function but -the transfer between CPU en GPU comes with a significant overhead. - -In order to get the best performances, we recommend to give only cupy -arrays to the functions and desactivate the conversion to numpy of the -result of the function with parameter ``to_numpy=False``. - -""" - -# Author: Remi Flamary -# Leo Gautheron -# -# License: MIT License - -import warnings - -from . import bregman -from . import da -from .bregman import sinkhorn -from .da import sinkhorn_lpl1_mm - -from . import utils -from .utils import dist, to_gpu, to_np - - -warnings.warn('This module is deprecated and will be removed in the next minor release of POT', category=DeprecationWarning) - - -__all__ = ["utils", "dist", "sinkhorn", - "sinkhorn_lpl1_mm", 'bregman', 'da', 'to_gpu', 'to_np'] - diff --git a/ot/gpu/bregman.py b/ot/gpu/bregman.py deleted file mode 100644 index 76af00e..0000000 --- a/ot/gpu/bregman.py +++ /dev/null @@ -1,196 +0,0 @@ -# -*- coding: utf-8 -*- -""" -Bregman projections for regularized OT with GPU -""" - -# Author: Remi Flamary -# Leo Gautheron -# -# License: MIT License - -import cupy as np # np used for matrix computation -import cupy as cp # cp used for cupy specific operations -from . import utils - - -def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9, - verbose=False, log=False, to_numpy=True, **kwargs): - r""" - Solve the entropic regularization optimal transport on GPU - - If the input matrix are in numpy format, they will be uploaded to the - GPU first which can incur significant time overhead. - - The function solves the following optimization problem: - - .. math:: - \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) - - s.t. \gamma 1 = a - - \gamma^T 1= b - - \gamma\geq 0 - where : - - - M is the (ns,nt) metric cost matrix - - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - a and b are source and target weights (sum to 1) - - The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [2]_ - - - Parameters - ---------- - a : np.ndarray (ns,) - samples weights in the source domain - b : np.ndarray (nt,) or np.ndarray (nt,nbb) - samples in the target domain, compute sinkhorn with multiple targets - and fixed M if b is a matrix (return OT loss + dual variables in log) - M : np.ndarray (ns,nt) - loss matrix - reg : float - Regularization term >0 - numItermax : int, optional - Max number of iterations - stopThr : float, optional - Stop threshold on error (>0) - verbose : bool, optional - Print information along iterations - log : bool, optional - record log if True - to_numpy : boolean, optional (default True) - If true convert back the GPU array result to numpy format. - - - Returns - ------- - gamma : (ns x nt) ndarray - Optimal transportation matrix for the given parameters - log : dict - log dictionary return only if log==True in parameters - - - References - ---------- - - .. [2] M. 
Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 - - - See Also - -------- - ot.lp.emd : Unregularized OT - ot.optim.cg : General regularized OT - - """ - - a = cp.asarray(a) - b = cp.asarray(b) - M = cp.asarray(M) - - if len(a) == 0: - a = np.ones((M.shape[0],)) / M.shape[0] - if len(b) == 0: - b = np.ones((M.shape[1],)) / M.shape[1] - - # init data - Nini = len(a) - Nfin = len(b) - - if len(b.shape) > 1: - nbb = b.shape[1] - else: - nbb = 0 - - if log: - log = {'err': []} - - # we assume that no distances are null except those of the diagonal of - # distances - if nbb: - u = np.ones((Nini, nbb)) / Nini - v = np.ones((Nfin, nbb)) / Nfin - else: - u = np.ones(Nini) / Nini - v = np.ones(Nfin) / Nfin - - # print(reg) - - # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute - K = np.empty(M.shape, dtype=M.dtype) - np.divide(M, -reg, out=K) - np.exp(K, out=K) - - # print(np.min(K)) - tmp2 = np.empty(b.shape, dtype=M.dtype) - - Kp = (1 / a).reshape(-1, 1) * K - cpt = 0 - err = 1 - while (err > stopThr and cpt < numItermax): - uprev = u - vprev = v - - KtransposeU = np.dot(K.T, u) - v = np.divide(b, KtransposeU) - u = 1. / np.dot(Kp, v) - - if (np.any(KtransposeU == 0) or - np.any(np.isnan(u)) or np.any(np.isnan(v)) or - np.any(np.isinf(u)) or np.any(np.isinf(v))): - # we have reached the machine precision - # come back to previous solution and quit loop - print('Warning: numerical errors at iteration', cpt) - u = uprev - v = vprev - break - if cpt % 10 == 0: - # we can speed up the process by checking for the error only all - # the 10th iterations - if nbb: - err = np.sqrt( - np.sum((u - uprev)**2) / np.sum((u)**2) - + np.sum((v - vprev)**2) / np.sum((v)**2) - ) - else: - # compute right marginal tmp2= (diag(u)Kdiag(v))^T1 - tmp2 = np.sum(u[:, None] * K * v[None, :], 0) - #tmp2=np.einsum('i,ij,j->j', u, K, v) - err = np.linalg.norm(tmp2 - b) # violation of marginal - if log: - log['err'].append(err) - - if verbose: - if cpt % 200 == 0: - print( - '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, err)) - cpt = cpt + 1 - if log: - log['u'] = u - log['v'] = v - - if nbb: # return only loss - #res = np.einsum('ik,ij,jk,ij->k', u, K, v, M) (explodes cupy memory) - res = np.empty(nbb) - for i in range(nbb): - res[i] = np.sum(u[:, None, i] * (K * M) * v[None, :, i]) - if to_numpy: - res = utils.to_np(res) - if log: - return res, log - else: - return res - - else: # return OT matrix - res = u.reshape((-1, 1)) * K * v.reshape((1, -1)) - if to_numpy: - res = utils.to_np(res) - if log: - return res, log - else: - return res - - -# define sinkhorn as sinkhorn_knopp -sinkhorn = sinkhorn_knopp diff --git a/ot/gpu/da.py b/ot/gpu/da.py deleted file mode 100644 index 7adb830..0000000 --- a/ot/gpu/da.py +++ /dev/null @@ -1,144 +0,0 @@ -# -*- coding: utf-8 -*- -""" -Domain adaptation with optimal transport with GPU implementation -""" - -# Author: Remi Flamary -# Nicolas Courty -# Michael Perrot -# Leo Gautheron -# -# License: MIT License - - -import cupy as np # np used for matrix computation -import cupy as cp # cp used for cupy specific operations -import numpy as npp -from . 
import utils - -from .bregman import sinkhorn - - -def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10, - numInnerItermax=200, stopInnerThr=1e-9, verbose=False, - log=False, to_numpy=True): - """ - Solve the entropic regularization optimal transport problem with nonconvex - group lasso regularization on GPU - - If the input matrix are in numpy format, they will be uploaded to the - GPU first which can incur significant time overhead. - - - The function solves the following optimization problem: - - .. math:: - \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega_e(\gamma) - + \eta \Omega_g(\gamma) - - s.t. \gamma 1 = a - - \gamma^T 1= b - - \gamma\geq 0 - where : - - - M is the (ns,nt) metric cost matrix - - :math:`\Omega_e` is the entropic regularization term - :math:`\Omega_e(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - :math:`\Omega_g` is the group lasso regulaization term - :math:`\Omega_g(\gamma)=\sum_{i,c} \|\gamma_{i,\mathcal{I}_c}\|^{1/2}_1` - where :math:`\mathcal{I}_c` are the index of samples from class c - in the source domain. - - a and b are source and target weights (sum to 1) - - The algorithm used for solving the problem is the generalised conditional - gradient as proposed in [5]_ [7]_ - - - Parameters - ---------- - a : np.ndarray (ns,) - samples weights in the source domain - labels_a : np.ndarray (ns,) - labels of samples in the source domain - b : np.ndarray (nt,) - samples weights in the target domain - M : np.ndarray (ns,nt) - loss matrix - reg : float - Regularization term for entropic regularization >0 - eta : float, optional - Regularization term for group lasso regularization >0 - numItermax : int, optional - Max number of iterations - numInnerItermax : int, optional - Max number of iterations (inner sinkhorn solver) - stopInnerThr : float, optional - Stop threshold on error (inner sinkhorn solver) (>0) - verbose : bool, optional - Print information along iterations - log : bool, optional - record log if True - to_numpy : boolean, optional (default True) - If true convert back the GPU array result to numpy format. - - - Returns - ------- - gamma : (ns x nt) ndarray - Optimal transportation matrix for the given parameters - log : dict - log dictionary return only if log==True in parameters - - - References - ---------- - - .. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, - "Optimal Transport for Domain Adaptation," in IEEE - Transactions on Pattern Analysis and Machine Intelligence , - vol.PP, no.99, pp.1-1 - .. [7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). - Generalized conditional gradient: analysis of convergence - and applications. arXiv preprint arXiv:1510.06567. - - See Also - -------- - ot.lp.emd : Unregularized OT - ot.bregman.sinkhorn : Entropic regularized OT - ot.optim.cg : General regularized OT - - """ - - a, labels_a, b, M = utils.to_gpu(a, labels_a, b, M) - - p = 0.5 - epsilon = 1e-3 - - indices_labels = [] - labels_a2 = cp.asnumpy(labels_a) - classes = npp.unique(labels_a2) - for c in classes: - idxc = utils.to_gpu(*npp.where(labels_a2 == c)) - indices_labels.append(idxc) - - W = np.zeros(M.shape) - - for cpt in range(numItermax): - Mreg = M + eta * W - transp = sinkhorn(a, b, Mreg, reg, numItermax=numInnerItermax, - stopThr=stopInnerThr, to_numpy=False) - # the transport has been computed. 
Check if classes are really - # separated - W = np.ones(M.shape) - for (i, c) in enumerate(classes): - - majs = np.sum(transp[indices_labels[i]], axis=0) - majs = p * ((majs + epsilon)**(p - 1)) - W[indices_labels[i]] = majs - - if to_numpy: - return utils.to_np(transp) - else: - return transp diff --git a/ot/gpu/utils.py b/ot/gpu/utils.py deleted file mode 100644 index 41e168a..0000000 --- a/ot/gpu/utils.py +++ /dev/null @@ -1,101 +0,0 @@ -# -*- coding: utf-8 -*- -""" -Utility functions for GPU -""" - -# Author: Remi Flamary -# Nicolas Courty -# Leo Gautheron -# -# License: MIT License - -import cupy as np # np used for matrix computation -import cupy as cp # cp used for cupy specific operations - - -def euclidean_distances(a, b, squared=False, to_numpy=True): - """ - Compute the pairwise euclidean distance between matrices a and b. - - If the input matrix are in numpy format, they will be uploaded to the - GPU first which can incur significant time overhead. - - Parameters - ---------- - a : np.ndarray (n, f) - first matrix - b : np.ndarray (m, f) - second matrix - to_numpy : boolean, optional (default True) - If true convert back the GPU array result to numpy format. - squared : boolean, optional (default False) - if True, return squared euclidean distance matrix - - Returns - ------- - c : (n x m) np.ndarray or cupy.ndarray - pairwise euclidean distance distance matrix - """ - - a, b = to_gpu(a, b) - - a2 = np.sum(np.square(a), 1) - b2 = np.sum(np.square(b), 1) - - c = -2 * np.dot(a, b.T) - c += a2[:, None] - c += b2[None, :] - - if not squared: - np.sqrt(c, out=c) - if to_numpy: - return to_np(c) - else: - return c - - -def dist(x1, x2=None, metric='sqeuclidean', to_numpy=True): - """Compute distance between samples in x1 and x2 on gpu - - Parameters - ---------- - - x1 : np.array (n1,d) - matrix with n1 samples of size d - x2 : np.array (n2,d), optional - matrix with n2 samples of size d (if None then x2=x1) - metric : str - Metric from 'sqeuclidean', 'euclidean', - - - Returns - ------- - - M : np.array (n1,n2) - distance matrix computed with given metric - - """ - if x2 is None: - x2 = x1 - if metric == "sqeuclidean": - return euclidean_distances(x1, x2, squared=True, to_numpy=to_numpy) - elif metric == "euclidean": - return euclidean_distances(x1, x2, squared=False, to_numpy=to_numpy) - else: - raise NotImplementedError - - -def to_gpu(*args): - """ Upload numpy arrays to GPU and return them""" - if len(args) > 1: - return (cp.asarray(x) for x in args) - else: - return cp.asarray(args[0]) - - -def to_np(*args): - """ convert GPU arras to numpy and return them""" - if len(args) > 1: - return (cp.asnumpy(x) for x in args) - else: - return cp.asnumpy(args[0]) diff --git a/test/test_gpu.py b/test/test_gpu.py deleted file mode 100644 index 8e62a74..0000000 --- a/test/test_gpu.py +++ /dev/null @@ -1,106 +0,0 @@ -"""Tests for module gpu for gpu acceleration """ - -# Author: Remi Flamary -# -# License: MIT License - -import numpy as np -import ot -import pytest - -try: # test if cudamat installed - import ot.gpu - nogpu = False -except ImportError: - nogpu = True - - -@pytest.mark.skipif(nogpu, reason="No GPU available") -def test_gpu_old_doctests(): - a = [.5, .5] - b = [.5, .5] - M = [[0., 1.], [1., 0.]] - G = ot.sinkhorn(a, b, M, 1) - np.testing.assert_allclose(G, np.array([[0.36552929, 0.13447071], - [0.13447071, 0.36552929]])) - - -@pytest.mark.skipif(nogpu, reason="No GPU available") -def test_gpu_dist(): - - rng = np.random.RandomState(0) - - for n_samples in [50, 100, 500, 1000]: 
- print(n_samples) - a = rng.rand(n_samples // 4, 100) - b = rng.rand(n_samples, 100) - - M = ot.dist(a.copy(), b.copy()) - M2 = ot.gpu.dist(a.copy(), b.copy()) - - np.testing.assert_allclose(M, M2, rtol=1e-10) - - M2 = ot.gpu.dist(a.copy(), b.copy(), metric='euclidean', to_numpy=False) - - # check raise not implemented wrong metric - with pytest.raises(NotImplementedError): - M2 = ot.gpu.dist(a.copy(), b.copy(), metric='cityblock', to_numpy=False) - - -@pytest.mark.skipif(nogpu, reason="No GPU available") -def test_gpu_sinkhorn(): - - rng = np.random.RandomState(0) - - for n_samples in [50, 100, 500, 1000]: - a = rng.rand(n_samples // 4, 100) - b = rng.rand(n_samples, 100) - - wa = ot.unif(n_samples // 4) - wb = ot.unif(n_samples) - - wb2 = np.random.rand(n_samples, 20) - wb2 /= wb2.sum(0, keepdims=True) - - M = ot.dist(a.copy(), b.copy()) - M2 = ot.gpu.dist(a.copy(), b.copy(), to_numpy=False) - - reg = 1 - - G = ot.sinkhorn(wa, wb, M, reg) - G1 = ot.gpu.sinkhorn(wa, wb, M, reg) - - np.testing.assert_allclose(G1, G, rtol=1e-10) - - # run all on gpu - ot.gpu.sinkhorn(wa, wb, M2, reg, to_numpy=False, log=True) - - # run sinkhorn for multiple targets - ot.gpu.sinkhorn(wa, wb2, M2, reg, to_numpy=False, log=True) - - -@pytest.mark.skipif(nogpu, reason="No GPU available") -def test_gpu_sinkhorn_lpl1(): - - rng = np.random.RandomState(0) - - for n_samples in [50, 100, 500]: - print(n_samples) - a = rng.rand(n_samples // 4, 100) - labels_a = np.random.randint(10, size=(n_samples // 4)) - b = rng.rand(n_samples, 100) - - wa = ot.unif(n_samples // 4) - wb = ot.unif(n_samples) - - M = ot.dist(a.copy(), b.copy()) - M2 = ot.gpu.dist(a.copy(), b.copy(), to_numpy=False) - - reg = 1 - - G = ot.da.sinkhorn_lpl1_mm(wa, labels_a, wb, M, reg) - G1 = ot.gpu.da.sinkhorn_lpl1_mm(wa, labels_a, wb, M, reg) - - np.testing.assert_allclose(G1, G, rtol=1e-10) - - ot.gpu.da.sinkhorn_lpl1_mm(wa, labels_a, wb, M2, reg, to_numpy=False, log=True) diff --git a/test/test_gromov.py b/test/test_gromov.py index 12fd2b9..9c85b92 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -188,7 +188,7 @@ def test_gromov2_gradients(): @pytest.skip_backend("jax", reason="test very slow with jax backend") @pytest.skip_backend("tf", reason="test very slow with tf backend") def test_entropic_gromov(nx): - n_samples = 50 # nb samples + n_samples = 10 # nb samples mu_s = np.array([0, 0]) cov_s = np.array([[1, 0], [0, 1]]) @@ -222,9 +222,9 @@ def test_entropic_gromov(nx): q, Gb.sum(0), atol=1e-04) # cf convergence gromov gw, log = ot.gromov.entropic_gromov_wasserstein2( - C1, C2, p, q, 'kl_loss', epsilon=1e-2, log=True) + C1, C2, p, q, 'kl_loss', max_iter=10, epsilon=1e-2, log=True) gwb, logb = ot.gromov.entropic_gromov_wasserstein2( - C1b, C2b, pb, qb, 'kl_loss', epsilon=1e-2, log=True) + C1b, C2b, pb, qb, 'kl_loss', max_iter=10, epsilon=1e-2, log=True) gwb = nx.to_numpy(gwb) G = log['T'] @@ -245,7 +245,7 @@ def test_entropic_gromov(nx): @pytest.skip_backend("tf", reason="test very slow with tf backend") def test_entropic_gromov_dtype_device(nx): # setup - n_samples = 50 # nb samples + n_samples = 5 # nb samples mu_s = np.array([0, 0]) cov_s = np.array([[1, 0], [0, 1]]) @@ -280,7 +280,7 @@ def test_entropic_gromov_dtype_device(nx): def test_pointwise_gromov(nx): - n_samples = 50 # nb samples + n_samples = 5 # nb samples mu_s = np.array([0, 0]) cov_s = np.array([[1, 0], [0, 1]]) @@ -331,14 +331,12 @@ def test_pointwise_gromov(nx): Gb = nx.to_numpy(nx.todense(Gb)) np.testing.assert_allclose(G, Gb, atol=1e-06) - 
np.testing.assert_allclose(float(logb['gw_dist_estimated']), 0.10342276348494964, atol=1e-8) - np.testing.assert_allclose(float(logb['gw_dist_std']), 0.0015952535464736394, atol=1e-8) @pytest.skip_backend("tf", reason="test very slow with tf backend") @pytest.skip_backend("jax", reason="test very slow with jax backend") def test_sampled_gromov(nx): - n_samples = 50 # nb samples + n_samples = 5 # nb samples mu_s = np.array([0, 0], dtype=np.float64) cov_s = np.array([[1, 0], [0, 1]], dtype=np.float64) @@ -365,9 +363,9 @@ def test_sampled_gromov(nx): return nx.abs(x - y) G, log = ot.gromov.sampled_gromov_wasserstein( - C1, C2, p, q, loss, max_iter=100, epsilon=1, log=True, verbose=True, random_state=42) + C1, C2, p, q, loss, max_iter=20, nb_samples_grad=2, epsilon=1, log=True, verbose=True, random_state=42) Gb, logb = ot.gromov.sampled_gromov_wasserstein( - C1b, C2b, pb, qb, lossb, max_iter=100, epsilon=1, log=True, verbose=True, random_state=42) + C1b, C2b, pb, qb, lossb, max_iter=20, nb_samples_grad=2, epsilon=1, log=True, verbose=True, random_state=42) Gb = nx.to_numpy(Gb) # check constraints @@ -377,13 +375,10 @@ def test_sampled_gromov(nx): np.testing.assert_allclose( q, Gb.sum(0), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose(float(logb['gw_dist_estimated']), 0.05679474884977278, atol=1e-08) - np.testing.assert_allclose(float(logb['gw_dist_std']), 0.0005986592106971995, atol=1e-08) - def test_gromov_barycenter(nx): - ns = 10 - nt = 20 + ns = 5 + nt = 8 Xs, ys = ot.datasets.make_data_classif('3gauss', ns, random_state=42) Xt, yt = ot.datasets.make_data_classif('3gauss2', nt, random_state=42) @@ -450,8 +445,8 @@ def test_gromov_barycenter(nx): @pytest.mark.filterwarnings("ignore:divide") def test_gromov_entropic_barycenter(nx): - ns = 10 - nt = 20 + ns = 5 + nt = 10 Xs, ys = ot.datasets.make_data_classif('3gauss', ns, random_state=42) Xt, yt = ot.datasets.make_data_classif('3gauss2', nt, random_state=42) @@ -517,7 +512,7 @@ def test_gromov_entropic_barycenter(nx): def test_fgw(nx): - n_samples = 50 # nb samples + n_samples = 20 # nb samples mu_s = np.array([0, 0]) cov_s = np.array([[1, 0], [0, 1]]) @@ -579,7 +574,7 @@ def test_fgw(nx): def test_fgw2_gradients(): - n_samples = 50 # nb samples + n_samples = 20 # nb samples mu_s = np.array([0, 0]) cov_s = np.array([[1, 0], [0, 1]]) @@ -625,8 +620,8 @@ def test_fgw2_gradients(): def test_fgw_barycenter(nx): np.random.seed(42) - ns = 50 - nt = 60 + ns = 10 + nt = 20 Xs, ys = ot.datasets.make_data_classif('3gauss', ns, random_state=42) Xt, yt = ot.datasets.make_data_classif('3gauss2', nt, random_state=42) @@ -674,7 +669,7 @@ def test_fgw_barycenter(nx): def test_gromov_wasserstein_linear_unmixing(nx): - n = 10 + n = 4 X1, y1 = ot.datasets.make_data_classif('3gauss', n, random_state=42) X2, y2 = ot.datasets.make_data_classif('3gauss2', n, random_state=42) @@ -709,10 +704,10 @@ def test_gromov_wasserstein_linear_unmixing(nx): tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 ) - np.testing.assert_allclose(unmixing1, nx.to_numpy(unmixing1b), atol=1e-06) - np.testing.assert_allclose(unmixing1, [1., 0.], atol=1e-01) - np.testing.assert_allclose(unmixing2, nx.to_numpy(unmixing2b), atol=1e-06) - np.testing.assert_allclose(unmixing2, [0., 1.], atol=1e-01) + np.testing.assert_allclose(unmixing1, nx.to_numpy(unmixing1b), atol=5e-06) + np.testing.assert_allclose(unmixing1, [1., 0.], atol=5e-01) + np.testing.assert_allclose(unmixing2, nx.to_numpy(unmixing2b), atol=5e-06) + np.testing.assert_allclose(unmixing2, [0., 1.], 
atol=5e-01) np.testing.assert_allclose(C1_emb, nx.to_numpy(C1b_emb), atol=1e-06) np.testing.assert_allclose(C2_emb, nx.to_numpy(C2b_emb), atol=1e-06) np.testing.assert_allclose(reconstruction1, nx.to_numpy(reconstruction1b), atol=1e-06) @@ -758,7 +753,7 @@ def test_gromov_wasserstein_linear_unmixing(nx): def test_gromov_wasserstein_dictionary_learning(nx): # create dataset composed from 2 structures which are repeated 5 times - shape = 10 + shape = 4 n_samples = 2 n_atoms = 2 projection = 'nonnegative_symmetric' @@ -795,7 +790,7 @@ def test_gromov_wasserstein_dictionary_learning(nx): for i in range(n_samples): _, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing( Cs[i], Cdict_init, p=ps[i], q=q, reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50 ) initial_total_reconstruction += reconstruction @@ -803,7 +798,7 @@ def test_gromov_wasserstein_dictionary_learning(nx): Cdict, log = ot.gromov.gromov_wasserstein_dictionary_learning( Cs, D=n_atoms, nt=shape, ps=ps, q=q, Cdict_init=Cdict_init, epochs=epochs, batch_size=2 * n_samples, learning_rate=1., reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200, + tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50, projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose ) # > Compute reconstruction of samples on learned dictionary without backend @@ -811,7 +806,7 @@ def test_gromov_wasserstein_dictionary_learning(nx): for i in range(n_samples): _, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing( Cs[i], Cdict, p=None, q=None, reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50 ) total_reconstruction += reconstruction @@ -822,7 +817,7 @@ def test_gromov_wasserstein_dictionary_learning(nx): Cdictb, log = ot.gromov.gromov_wasserstein_dictionary_learning( Csb, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=Cdict_initb, epochs=epochs, batch_size=n_samples, learning_rate=1., reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200, + tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50, projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose ) # Compute reconstruction of samples on learned dictionary @@ -830,7 +825,7 @@ def test_gromov_wasserstein_dictionary_learning(nx): for i in range(n_samples): _, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing( Csb[i], Cdictb, p=psb[i], q=qb, reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50 ) total_reconstruction_b += reconstruction @@ -846,7 +841,7 @@ def test_gromov_wasserstein_dictionary_learning(nx): Cdict_bis, log = ot.gromov.gromov_wasserstein_dictionary_learning( Cs, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=None, epochs=epochs, batch_size=n_samples, learning_rate=1., reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200, + tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50, projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose ) # > Compute reconstruction of samples on learned dictionary @@ -854,7 +849,7 @@ def test_gromov_wasserstein_dictionary_learning(nx): for i in range(n_samples): _, _, _, reconstruction = 
ot.gromov.gromov_wasserstein_linear_unmixing( Cs[i], Cdict_bis, p=ps[i], q=q, reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50 ) total_reconstruction_bis += reconstruction @@ -865,7 +860,7 @@ def test_gromov_wasserstein_dictionary_learning(nx): Cdictb_bis, log = ot.gromov.gromov_wasserstein_dictionary_learning( Csb, D=n_atoms, nt=shape, ps=psb, q=qb, Cdict_init=None, epochs=epochs, batch_size=n_samples, learning_rate=1., reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200, + tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50, projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose ) # > Compute reconstruction of samples on learned dictionary @@ -873,7 +868,7 @@ def test_gromov_wasserstein_dictionary_learning(nx): for i in range(n_samples): _, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing( Csb[i], Cdictb_bis, p=None, q=None, reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50 ) total_reconstruction_b_bis += reconstruction @@ -892,7 +887,7 @@ def test_gromov_wasserstein_dictionary_learning(nx): Cdict_bis2, log = ot.gromov.gromov_wasserstein_dictionary_learning( Cs, D=n_atoms, nt=shape, ps=ps, q=q, Cdict_init=Cdict, epochs=epochs, batch_size=n_samples, learning_rate=10., reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200, + tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50, projection=projection, use_log=use_log, use_adam_optimizer=use_adam_optimizer, verbose=verbose ) # > Compute reconstruction of samples on learned dictionary @@ -900,7 +895,7 @@ def test_gromov_wasserstein_dictionary_learning(nx): for i in range(n_samples): _, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing( Cs[i], Cdict_bis2, p=ps[i], q=q, reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50 ) total_reconstruction_bis2 += reconstruction @@ -911,7 +906,7 @@ def test_gromov_wasserstein_dictionary_learning(nx): Cdictb_bis2, log = ot.gromov.gromov_wasserstein_dictionary_learning( Csb, D=n_atoms, nt=shape, ps=psb, q=qb, Cdict_init=Cdictb, epochs=epochs, batch_size=n_samples, learning_rate=10., reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200, + tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50, projection=projection, use_log=use_log, use_adam_optimizer=use_adam_optimizer, verbose=verbose ) # > Compute reconstruction of samples on learned dictionary @@ -919,7 +914,7 @@ def test_gromov_wasserstein_dictionary_learning(nx): for i in range(n_samples): _, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing( Csb[i], Cdictb_bis2, p=psb[i], q=qb, reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50 ) total_reconstruction_b_bis2 += reconstruction @@ -929,7 +924,7 @@ def test_gromov_wasserstein_dictionary_learning(nx): def test_fused_gromov_wasserstein_linear_unmixing(nx): - n = 10 + n = 4 X1, y1 = ot.datasets.make_data_classif('3gauss', n, random_state=42) X2, y2 = ot.datasets.make_data_classif('3gauss2', n, random_state=42) F, y = ot.datasets.make_data_classif('3gauss', n, random_state=42) @@ -947,28 +942,28 @@ def 
test_fused_gromov_wasserstein_linear_unmixing(nx): unmixing1, C1_emb, Y1_emb, OT, reconstruction1 = ot.gromov.fused_gromov_wasserstein_linear_unmixing( C1, F, Cdict, Ydict, p=p, q=p, alpha=0.5, reg=reg, - tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=20, max_iter_inner=200 + tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=10, max_iter_inner=50 ) unmixing1b, C1b_emb, Y1b_emb, OTb, reconstruction1b = ot.gromov.fused_gromov_wasserstein_linear_unmixing( C1b, Fb, Cdictb, Ydictb, p=None, q=None, alpha=0.5, reg=reg, - tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=20, max_iter_inner=200 + tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=10, max_iter_inner=50 ) unmixing2, C2_emb, Y2_emb, OT, reconstruction2 = ot.gromov.fused_gromov_wasserstein_linear_unmixing( C2, F, Cdict, Ydict, p=None, q=None, alpha=0.5, reg=reg, - tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=20, max_iter_inner=200 + tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=10, max_iter_inner=50 ) unmixing2b, C2b_emb, Y2b_emb, OTb, reconstruction2b = ot.gromov.fused_gromov_wasserstein_linear_unmixing( C2b, Fb, Cdictb, Ydictb, p=pb, q=pb, alpha=0.5, reg=reg, - tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=20, max_iter_inner=200 + tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=10, max_iter_inner=50 ) - np.testing.assert_allclose(unmixing1, nx.to_numpy(unmixing1b), atol=1e-06) - np.testing.assert_allclose(unmixing1, [1., 0.], atol=1e-01) - np.testing.assert_allclose(unmixing2, nx.to_numpy(unmixing2b), atol=1e-06) - np.testing.assert_allclose(unmixing2, [0., 1.], atol=1e-01) + np.testing.assert_allclose(unmixing1, nx.to_numpy(unmixing1b), atol=4e-06) + np.testing.assert_allclose(unmixing1, [1., 0.], atol=4e-01) + np.testing.assert_allclose(unmixing2, nx.to_numpy(unmixing2b), atol=4e-06) + np.testing.assert_allclose(unmixing2, [0., 1.], atol=4e-01) np.testing.assert_allclose(C1_emb, nx.to_numpy(C1b_emb), atol=1e-03) np.testing.assert_allclose(C2_emb, nx.to_numpy(C2b_emb), atol=1e-03) np.testing.assert_allclose(Y1_emb, nx.to_numpy(Y1b_emb), atol=1e-03) @@ -983,22 +978,22 @@ def test_fused_gromov_wasserstein_linear_unmixing(nx): unmixing1, C1_emb, Y1_emb, OT, reconstruction1 = ot.gromov.fused_gromov_wasserstein_linear_unmixing( C1, F, Cdict, Ydict, p=p, q=p, alpha=0.5, reg=reg, - tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=20, max_iter_inner=200 + tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=10, max_iter_inner=50 ) unmixing1b, C1b_emb, Y1b_emb, OTb, reconstruction1b = ot.gromov.fused_gromov_wasserstein_linear_unmixing( C1b, Fb, Cdictb, Ydictb, p=None, q=None, alpha=0.5, reg=reg, - tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=20, max_iter_inner=200 + tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=10, max_iter_inner=50 ) unmixing2, C2_emb, Y2_emb, OT, reconstruction2 = ot.gromov.fused_gromov_wasserstein_linear_unmixing( C2, F, Cdict, Ydict, p=None, q=None, alpha=0.5, reg=reg, - tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=20, max_iter_inner=200 + tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=10, max_iter_inner=50 ) unmixing2b, C2b_emb, Y2b_emb, OTb, reconstruction2b = ot.gromov.fused_gromov_wasserstein_linear_unmixing( C2b, Fb, Cdictb, Ydictb, p=pb, q=pb, alpha=0.5, reg=reg, - tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=20, max_iter_inner=200 + tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=10, max_iter_inner=50 ) np.testing.assert_allclose(unmixing1, nx.to_numpy(unmixing1b), atol=1e-06) @@ -1018,7 +1013,7 
@@ def test_fused_gromov_wasserstein_linear_unmixing(nx): def test_fused_gromov_wasserstein_dictionary_learning(nx): # create dataset composed from 2 structures which are repeated 5 times - shape = 10 + shape = 4 n_samples = 2 n_atoms = 2 projection = 'nonnegative_symmetric' @@ -1060,7 +1055,7 @@ def test_fused_gromov_wasserstein_dictionary_learning(nx): for i in range(n_samples): _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing( Cs[i], Ys[i], Cdict_init, Ydict_init, p=ps[i], q=q, - alpha=alpha, reg=0., tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + alpha=alpha, reg=0., tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50 ) initial_total_reconstruction += reconstruction @@ -1069,7 +1064,7 @@ def test_fused_gromov_wasserstein_dictionary_learning(nx): Cdict, Ydict, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning( Cs, Ys, D=n_atoms, nt=shape, ps=ps, q=q, Cdict_init=Cdict_init, Ydict_init=Ydict_init, epochs=epochs, batch_size=n_samples, learning_rate_C=1., learning_rate_Y=1., alpha=alpha, reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200, + tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50, projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose ) # > Compute reconstruction of samples on learned dictionary @@ -1077,7 +1072,7 @@ def test_fused_gromov_wasserstein_dictionary_learning(nx): for i in range(n_samples): _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing( Cs[i], Ys[i], Cdict, Ydict, p=None, q=None, alpha=alpha, reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50 ) total_reconstruction += reconstruction # Compare both @@ -1088,7 +1083,7 @@ def test_fused_gromov_wasserstein_dictionary_learning(nx): Cdictb, Ydictb, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning( Csb, Ysb, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=Cdict_initb, Ydict_init=Ydict_initb, epochs=epochs, batch_size=2 * n_samples, learning_rate_C=1., learning_rate_Y=1., alpha=alpha, reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200, + tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50, projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose ) # > Compute reconstruction of samples on learned dictionary @@ -1096,7 +1091,7 @@ def test_fused_gromov_wasserstein_dictionary_learning(nx): for i in range(n_samples): _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing( Csb[i], Ysb[i], Cdictb, Ydictb, p=psb[i], q=qb, alpha=alpha, reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50 ) total_reconstruction_b += reconstruction @@ -1111,7 +1106,7 @@ def test_fused_gromov_wasserstein_dictionary_learning(nx): Cdict_bis, Ydict_bis, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning( Cs, Ys, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=None, Ydict_init=None, epochs=epochs, batch_size=n_samples, learning_rate_C=1., learning_rate_Y=1., alpha=alpha, reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200, + tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50, projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose ) # > Compute reconstruction of samples on learned 
dictionary @@ -1119,7 +1114,7 @@ def test_fused_gromov_wasserstein_dictionary_learning(nx): for i in range(n_samples): _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing( Cs[i], Ys[i], Cdict_bis, Ydict_bis, p=ps[i], q=q, alpha=alpha, reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50 ) total_reconstruction_bis += reconstruction @@ -1130,7 +1125,7 @@ def test_fused_gromov_wasserstein_dictionary_learning(nx): Cdictb_bis, Ydictb_bis, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning( Csb, Ysb, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=None, Ydict_init=None, epochs=epochs, batch_size=n_samples, learning_rate_C=1., learning_rate_Y=1., alpha=alpha, reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200, + tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50, projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose ) @@ -1139,7 +1134,7 @@ def test_fused_gromov_wasserstein_dictionary_learning(nx): for i in range(n_samples): _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing( Csb[i], Ysb[i], Cdictb_bis, Ydictb_bis, p=psb[i], q=qb, alpha=alpha, reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50 ) total_reconstruction_b_bis += reconstruction @@ -1156,7 +1151,7 @@ def test_fused_gromov_wasserstein_dictionary_learning(nx): Cdict_bis2, Ydict_bis2, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning( Cs, Ys, D=n_atoms, nt=shape, ps=ps, q=q, Cdict_init=Cdict, Ydict_init=Ydict, epochs=epochs, batch_size=n_samples, learning_rate_C=10., learning_rate_Y=10., alpha=alpha, reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200, + tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50, projection=projection, use_log=use_log, use_adam_optimizer=use_adam_optimizer, verbose=verbose ) # > Compute reconstruction of samples on learned dictionary @@ -1164,7 +1159,7 @@ def test_fused_gromov_wasserstein_dictionary_learning(nx): for i in range(n_samples): _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing( Cs[i], Ys[i], Cdict_bis2, Ydict_bis2, p=ps[i], q=q, alpha=alpha, reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50 ) total_reconstruction_bis2 += reconstruction @@ -1175,7 +1170,7 @@ def test_fused_gromov_wasserstein_dictionary_learning(nx): Cdictb_bis2, Ydictb_bis2, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning( Csb, Ysb, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=Cdictb, Ydict_init=Ydictb, epochs=epochs, batch_size=n_samples, learning_rate_C=10., learning_rate_Y=10., alpha=alpha, reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200, + tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50, projection=projection, use_log=use_log, use_adam_optimizer=use_adam_optimizer, verbose=verbose ) @@ -1184,7 +1179,7 @@ def test_fused_gromov_wasserstein_dictionary_learning(nx): for i in range(n_samples): _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing( Csb[i], Ysb[i], Cdictb_bis2, Ydictb_bis2, p=None, q=None, alpha=alpha, reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + tol_outer=tol, tol_inner=tol, 
max_iter_outer=10, max_iter_inner=50 ) total_reconstruction_b_bis2 += reconstruction -- cgit v1.2.3
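The last hunks shrink the (fused) Gromov-Wasserstein dictionary-learning tests (smaller graphs, fewer inner and outer iterations) without changing the call pattern they exercise. For reference, a minimal sketch of that unmixing pattern outside the test suite (graph size, tolerances and result names are illustrative assumptions; the function and its keyword arguments are the ones used in the tests above):

import numpy as np
import ot

n = 10  # illustrative graph size
X1, _ = ot.datasets.make_data_classif('3gauss', n, random_state=42)
X2, _ = ot.datasets.make_data_classif('3gauss2', n, random_state=42)

C1 = ot.dist(X1, X1)        # pairwise structure matrix of the first graph
C2 = ot.dist(X2, X2)        # pairwise structure matrix of the second graph
Cdict = np.stack([C1, C2])  # dictionary made of two atoms
p = ot.unif(n)              # uniform node weights

# Project C1 onto the dictionary: w holds the unmixing weights, C_emb the
# embedded (reconstructed) structure, T the GW transport plan and err the
# reconstruction error reported by the solver.
w, C_emb, T, err = ot.gromov.gromov_wasserstein_linear_unmixing(
    C1, Cdict, p=p, q=p, reg=0.,
    tol_outer=1e-5, tol_inner=1e-5, max_iter_outer=20, max_iter_inner=200
)

Here reg=0. mirrors the unregularized setting exercised by most of the tests; the fused variant adds the node features and an alpha trade-off in the same way as in the test functions above.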