# Provenance: this module is the payload of the git patch described below
# (served by cgit v1.2.3, limited to 'test/test_coot.py'). The metadata is
# kept verbatim for reference; the code that follows has been reformatted
# and refactored, so the original hunk header no longer matches line counts.
#
# From 897026ea1f5c35ba9e881433bc61490e70776b8c Mon Sep 17 00:00:00 2001
# From: Huy Tran
# Date: Wed, 22 Mar 2023 08:13:53 +0100
# Subject: [MRG] CO-Optimal Transport solver (#447)
# MIME-Version: 1.0
# Content-Type: text/plain; charset=UTF-8
# Content-Transfer-Encoding: 8bit
#
# * Allow warmstart in sinkhorn and sinkhorn_log
# * Added argument for warmstart of dual vectors in Sinkhorn-based methods in
# * Add the number of the PR
# * [WIP] CO-Optimal Transport
# * Revert "[WIP] CO-Optimal Transport"
#   This reverts commit f3d36b2705013409ac69b346585e311bc25fcfb7.
# * reformat with PEP8
# * Fix W291 trailing whitespace error in pep8 test
# * Rearange position of warmstart argument and edit its description
# * Implementation of CO-Optimal Transport
# * Optimize code and edit documentation
# * fix backend bug in test cases
# * fix backend bug
# * fix backend bug
# * Add examples on COOT
# * Modify API and edit example
# * Edit API
# * minor edit of examples and release
# * fix bug in coot
# * fix doc examples
# * more fix of doc
# * restart CI
# * reordering ref
# * add more tests
# * add more tests
# * add test verbose
# * fix PEP8 bug
# * fix PEP8 bug
# * fix PEP8 bug
# * fix pytest bug
# * edit doc for better display
#
# ---------
#
# Co-authored-by: Rémi Flamary
# Co-authored-by: Alexandre Gramfort
# ---
#  test/test_coot.py | 359 ++++++++++++++++++++++++++++++++++++++++++++++
#  1 file changed, 359 insertions(+)
#  create mode 100644 test/test_coot.py
#
# diff --git a/test/test_coot.py b/test/test_coot.py
# new file mode 100644
# index 0000000..ef68a9b
# --- /dev/null
# +++ b/test/test_coot.py
# @@ -0,0 +1,359 @@

"""Tests for module COOT on OT """

# Author: Quang Huy Tran
#
# License: MIT License

import numpy as np
import ot
from ot.coot import co_optimal_transport as coot
from ot.coot import co_optimal_transport2 as coot2
import pytest


# Number of feature columns produced by ``make_2D_samples_gauss``.
_N_FEATURES = 2


def _make_reversed_samples(nx, n_samples, mu_s, random_state):
    """Build a 2D Gaussian sample set and its row-reversed copy.

    Parameters
    ----------
    nx : backend object
        Must expose ``from_numpy``.
        NOTE(review): presumably the pytest fixture from the project's
        conftest -- it is not defined in this file.
    n_samples : int
        Number of rows to draw.
    mu_s : sequence of two numbers
        Mean passed to ``ot.datasets.make_2D_samples_gauss``.
    random_state : int
        Seed passed to ``ot.datasets.make_2D_samples_gauss``.

    Returns
    -------
    tuple
        ``(xs, xt, xs_nx, xt_nx)``: ``xt`` is ``xs`` with its row order
        reversed, and the ``*_nx`` items are the same two arrays converted
        with ``nx.from_numpy``.
    """
    cov_s = np.array([[1, 0], [0, 1]])
    xs = ot.datasets.make_2D_samples_gauss(
        n_samples, np.array(mu_s), cov_s, random_state=random_state)
    xt = xs[::-1].copy()
    return xs, xt, nx.from_numpy(xs), nx.from_numpy(xt)


def _assert_expected_couplings(pi_sample, pi_feature, n_samples):
    """Check couplings against the anti-identity / identity reference.

    Because the target is the source with rows reversed, the tests expect
    the sample coupling to be the flipped identity scaled by
    ``1 / n_samples`` and the feature coupling to be the identity scaled
    by ``1 / 2``.
    """
    anti_id_sample = np.flipud(np.eye(n_samples, n_samples)) / n_samples
    id_feature = np.eye(_N_FEATURES, _N_FEATURES) / _N_FEATURES

    np.testing.assert_allclose(pi_sample, anti_id_sample, atol=1e-04)
    np.testing.assert_allclose(pi_feature, id_feature, atol=1e-04)


def _assert_uniform_marginals(pi_sample, pi_feature, n_samples):
    """Check that both couplings have uniform row and column sums."""
    p_samp = ot.unif(n_samples)
    p_feat = ot.unif(_N_FEATURES)

    for axis in (0, 1):
        np.testing.assert_allclose(p_samp, pi_sample.sum(axis), atol=1e-04)
        np.testing.assert_allclose(p_feat, pi_feature.sum(axis), atol=1e-04)


def _assert_log_structure(log, n_samples):
    """Check the sizes of the entries stored in a COOT ``log`` dict.

    ``duals_sample`` and ``duals_feature`` must each hold two vectors (of
    length ``n_samples`` and 2 respectively) and ``distances`` must be
    non-empty.
    """
    duals_sample, duals_feature = log["duals_sample"], log["duals_feature"]
    assert len(duals_sample) == 2
    assert len(duals_feature) == 2
    assert len(duals_sample[0]) == n_samples
    assert len(duals_sample[1]) == n_samples
    assert len(duals_feature[0]) == _N_FEATURES
    assert len(duals_feature[1]) == _N_FEATURES

    assert len(log["distances"]) >= 1


@pytest.mark.parametrize("verbose", [False, True, 1, 0])
def test_coot(nx, verbose):
    """Default COOT recovers the row reversal, for every ``verbose``."""
    n_samples = 60  # nb samples
    xs, xt, xs_nx, xt_nx = _make_reversed_samples(
        nx, n_samples, mu_s=[0, 0], random_state=4)

    # test couplings
    pi_sample, pi_feature = coot(X=xs, Y=xt, verbose=verbose)
    pi_sample_nx, pi_feature_nx = coot(X=xs_nx, Y=xt_nx, verbose=verbose)
    pi_sample_nx = nx.to_numpy(pi_sample_nx)
    pi_feature_nx = nx.to_numpy(pi_feature_nx)

    _assert_expected_couplings(pi_sample, pi_feature, n_samples)
    _assert_expected_couplings(pi_sample_nx, pi_feature_nx, n_samples)

    # test marginal distributions
    _assert_uniform_marginals(pi_sample_nx, pi_feature_nx, n_samples)
    _assert_uniform_marginals(pi_sample, pi_feature, n_samples)

    # test COOT distance
    coot_np = coot2(X=xs, Y=xt, verbose=verbose)
    coot_nx = nx.to_numpy(coot2(X=xs_nx, Y=xt_nx, verbose=verbose))
    np.testing.assert_allclose(coot_np, 0, atol=1e-08)
    np.testing.assert_allclose(coot_nx, 0, atol=1e-08)


def test_entropic_coot(nx):
    """Entropic COOT gives matching NumPy and backend results."""
    n_samples = 60  # nb samples
    xs, xt, xs_nx, xt_nx = _make_reversed_samples(
        nx, n_samples, mu_s=[0, 0], random_state=4)

    epsilon = (1, 1e-1)
    nits_ot = 2000

    # test couplings
    pi_sample, pi_feature = coot(X=xs, Y=xt, epsilon=epsilon, nits_ot=nits_ot)
    pi_sample_nx, pi_feature_nx = coot(
        X=xs_nx, Y=xt_nx, epsilon=epsilon, nits_ot=nits_ot)
    pi_sample_nx = nx.to_numpy(pi_sample_nx)
    pi_feature_nx = nx.to_numpy(pi_feature_nx)

    # With regularisation only NumPy/backend agreement is checked, not the
    # anti-identity reference used by the unregularised tests.
    np.testing.assert_allclose(pi_sample, pi_sample_nx, atol=1e-04)
    np.testing.assert_allclose(pi_feature, pi_feature_nx, atol=1e-04)

    # test marginal distributions
    _assert_uniform_marginals(pi_sample_nx, pi_feature_nx, n_samples)
    _assert_uniform_marginals(pi_sample, pi_feature, n_samples)

    # test entropic COOT distance
    coot_np = coot2(X=xs, Y=xt, epsilon=epsilon, nits_ot=nits_ot)
    coot_nx = nx.to_numpy(
        coot2(X=xs_nx, Y=xt_nx, epsilon=epsilon, nits_ot=nits_ot))

    np.testing.assert_allclose(coot_np, coot_nx, atol=1e-08)


def test_coot_with_linear_terms(nx):
    """COOT with ``alpha``/``M_samp``/``M_feat`` keeps the same optimum."""
    n_samples = 60  # nb samples
    xs, xt, xs_nx, xt_nx = _make_reversed_samples(
        nx, n_samples, mu_s=[0, 0], random_state=4)

    # Linear costs that are zero exactly on the expected couplings:
    # ``np.fliplr`` returns a view, so filling its diagonal zeroes the
    # anti-diagonal of ``M_samp`` in place.
    M_samp = np.ones((n_samples, n_samples))
    np.fill_diagonal(np.fliplr(M_samp), 0)
    M_feat = np.ones((_N_FEATURES, _N_FEATURES))
    np.fill_diagonal(M_feat, 0)
    M_samp_nx, M_feat_nx = nx.from_numpy(M_samp), nx.from_numpy(M_feat)

    alpha = (1, 2)

    # test couplings
    pi_sample, pi_feature = coot(
        X=xs, Y=xt, alpha=alpha, M_samp=M_samp, M_feat=M_feat)
    pi_sample_nx, pi_feature_nx = coot(
        X=xs_nx, Y=xt_nx, alpha=alpha, M_samp=M_samp_nx, M_feat=M_feat_nx)
    pi_sample_nx = nx.to_numpy(pi_sample_nx)
    pi_feature_nx = nx.to_numpy(pi_feature_nx)

    _assert_expected_couplings(pi_sample, pi_feature, n_samples)
    _assert_expected_couplings(pi_sample_nx, pi_feature_nx, n_samples)

    # test marginal distributions
    _assert_uniform_marginals(pi_sample_nx, pi_feature_nx, n_samples)
    _assert_uniform_marginals(pi_sample, pi_feature, n_samples)

    # test COOT distance
    coot_np = coot2(X=xs, Y=xt, alpha=alpha, M_samp=M_samp, M_feat=M_feat)
    coot_nx = nx.to_numpy(
        coot2(X=xs_nx, Y=xt_nx, alpha=alpha,
              M_samp=M_samp_nx, M_feat=M_feat_nx))
    np.testing.assert_allclose(coot_np, 0, atol=1e-08)
    np.testing.assert_allclose(coot_nx, 0, atol=1e-08)


def test_coot_raise_value_error(nx):
    """Invalid ``method_sinkhorn``, ``epsilon`` or ``alpha`` must raise."""
    n_samples = 80  # nb samples
    xs, xt, xs_nx, xt_nx = _make_reversed_samples(
        nx, n_samples, mu_s=[2, 4], random_state=43)

    # raise value error of method sinkhorn
    with pytest.raises(ValueError):
        coot(X=xs, Y=xt, method_sinkhorn="not_sinkhorn")
    with pytest.raises(ValueError):
        coot(X=xs_nx, Y=xt_nx, method_sinkhorn="not_sinkhorn")

    # raise value error for epsilon
    with pytest.raises(ValueError):
        coot(X=xs, Y=xt, epsilon=(1, 2, 3))
    with pytest.raises(ValueError):
        coot(X=xs_nx, Y=xt_nx, epsilon=[1, 2, 3, 4])

    # raise value error for alpha
    with pytest.raises(ValueError):
        coot(X=xs, Y=xt, alpha=[1])
    with pytest.raises(ValueError):
        coot(X=xs_nx, Y=xt_nx, alpha=np.arange(4))


def test_coot_warmstart(nx):
    """COOT started from a ``warmstart`` dict still recovers the reversal."""
    n_samples = 80  # nb samples
    xs, xt, xs_nx, xt_nx = _make_reversed_samples(
        nx, n_samples, mu_s=[2, 3], random_state=125)

    # Local seeded generator: the warm start is reproducible from run to run
    # and the global NumPy random state is left untouched.
    rng = np.random.RandomState(42)

    # initialize warmstart
    init_pi_sample = rng.rand(n_samples, n_samples)
    init_pi_sample = init_pi_sample / np.sum(init_pi_sample)
    init_pi_sample_nx = nx.from_numpy(init_pi_sample)

    init_pi_feature = rng.rand(_N_FEATURES, _N_FEATURES)
    # NOTE(review): ``x /= x / x.sum()`` does NOT normalise ``x``; it sets
    # every entry to the pre-division total, giving a constant matrix.
    # The sample coupling above suggests ``x /= x.sum()`` was intended.
    # Left unchanged on purpose: a genuinely random feature coupling may
    # steer the solver to a different solution and break the assertions
    # below -- confirm against the solver before changing it.
    init_pi_feature /= init_pi_feature / np.sum(init_pi_feature)
    init_pi_feature_nx = nx.from_numpy(init_pi_feature)

    # Dual vectors drawn uniformly from [-1, 1).
    init_duals_sample = (rng.rand(n_samples) * 2 - 1,
                         rng.rand(n_samples) * 2 - 1)
    init_duals_sample_nx = (nx.from_numpy(init_duals_sample[0]),
                            nx.from_numpy(init_duals_sample[1]))

    init_duals_feature = (rng.rand(_N_FEATURES) * 2 - 1,
                          rng.rand(_N_FEATURES) * 2 - 1)
    init_duals_feature_nx = (nx.from_numpy(init_duals_feature[0]),
                             nx.from_numpy(init_duals_feature[1]))

    warmstart = {
        "pi_sample": init_pi_sample,
        "pi_feature": init_pi_feature,
        "duals_sample": init_duals_sample,
        "duals_feature": init_duals_feature
    }

    warmstart_nx = {
        "pi_sample": init_pi_sample_nx,
        "pi_feature": init_pi_feature_nx,
        "duals_sample": init_duals_sample_nx,
        "duals_feature": init_duals_feature_nx
    }

    # test couplings
    pi_sample, pi_feature = coot(X=xs, Y=xt, warmstart=warmstart)
    pi_sample_nx, pi_feature_nx = coot(
        X=xs_nx, Y=xt_nx, warmstart=warmstart_nx)
    pi_sample_nx = nx.to_numpy(pi_sample_nx)
    pi_feature_nx = nx.to_numpy(pi_feature_nx)

    _assert_expected_couplings(pi_sample, pi_feature, n_samples)
    _assert_expected_couplings(pi_sample_nx, pi_feature_nx, n_samples)

    # test marginal distributions
    _assert_uniform_marginals(pi_sample_nx, pi_feature_nx, n_samples)
    _assert_uniform_marginals(pi_sample, pi_feature, n_samples)

    # test COOT distance
    coot_np = coot2(X=xs, Y=xt, warmstart=warmstart)
    coot_nx = nx.to_numpy(coot2(X=xs_nx, Y=xt_nx, warmstart=warmstart_nx))
    np.testing.assert_allclose(coot_np, 0, atol=1e-08)
    np.testing.assert_allclose(coot_nx, 0, atol=1e-08)


def test_coot_log(nx):
    """With ``log=True`` both solvers return a correctly sized log dict."""
    n_samples = 90  # nb samples
    xs, xt, xs_nx, xt_nx = _make_reversed_samples(
        nx, n_samples, mu_s=[-2, 0], random_state=43)

    # test with couplings
    pi_sample, pi_feature, log = coot(X=xs, Y=xt, log=True)
    pi_sample_nx, pi_feature_nx, log_nx = coot(X=xs_nx, Y=xt_nx, log=True)

    _assert_log_structure(log, n_samples)
    _assert_log_structure(log_nx, n_samples)

    # test with coot distance
    coot_np, log = coot2(X=xs, Y=xt, log=True)
    coot_nx, log_nx = coot2(X=xs_nx, Y=xt_nx, log=True)

    _assert_log_structure(log, n_samples)
    _assert_log_structure(log_nx, n_samples)