diff options
Diffstat (limited to 'test/test_bregman.py')
-rw-r--r-- | test/test_bregman.py | 81 |
1 files changed, 20 insertions, 61 deletions
diff --git a/test/test_bregman.py b/test/test_bregman.py index 1419f9b..6c37984 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -155,8 +155,7 @@ def test_sinkhorn_backends(nx): G = ot.sinkhorn(a, a, M, 1) - ab = nx.from_numpy(a) - M_nx = nx.from_numpy(M) + ab, M_nx = nx.from_numpy(a, M) Gb = ot.sinkhorn(ab, ab, M_nx, 1) @@ -176,8 +175,7 @@ def test_sinkhorn2_backends(nx): G = ot.sinkhorn(a, a, M, 1) - ab = nx.from_numpy(a) - M_nx = nx.from_numpy(M) + ab, M_nx = nx.from_numpy(a, M) Gb = ot.sinkhorn2(ab, ab, M_nx, 1) @@ -260,8 +258,7 @@ def test_sinkhorn_variants(nx): M = ot.dist(x, x) - ub = nx.from_numpy(u) - M_nx = nx.from_numpy(M) + ub, M_nx = nx.from_numpy(u, M) G = ot.sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10) Gl = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method='sinkhorn_log', stopThr=1e-10)) @@ -298,8 +295,7 @@ def test_sinkhorn_variants_dtype_device(nx, method): for tp in nx.__type_list__: print(nx.dtype_device(tp)) - ub = nx.from_numpy(u, type_as=tp) - Mb = nx.from_numpy(M, type_as=tp) + ub, Mb = nx.from_numpy(u, M, type_as=tp) Gb = ot.sinkhorn(ub, ub, Mb, 1, method=method, stopThr=1e-10) @@ -318,8 +314,7 @@ def test_sinkhorn2_variants_dtype_device(nx, method): for tp in nx.__type_list__: print(nx.dtype_device(tp)) - ub = nx.from_numpy(u, type_as=tp) - Mb = nx.from_numpy(M, type_as=tp) + ub, Mb = nx.from_numpy(u, M, type_as=tp) lossb = ot.sinkhorn2(ub, ub, Mb, 1, method=method, stopThr=1e-10) @@ -337,8 +332,7 @@ def test_sinkhorn2_variants_device_tf(method): # Check that everything stays on the CPU with tf.device("/CPU:0"): - ub = nx.from_numpy(u) - Mb = nx.from_numpy(M) + ub, Mb = nx.from_numpy(u, M) Gb = ot.sinkhorn(ub, ub, Mb, 1, method=method, stopThr=1e-10) lossb = ot.sinkhorn2(ub, ub, Mb, 1, method=method, stopThr=1e-10) nx.assert_same_dtype_device(Mb, Gb) @@ -346,8 +340,7 @@ def test_sinkhorn2_variants_device_tf(method): if len(tf.config.list_physical_devices('GPU')) > 0: # Check that everything happens on the GPU - ub = nx.from_numpy(u) - Mb = nx.from_numpy(M) + ub, Mb = nx.from_numpy(u, M) Gb = ot.sinkhorn(ub, ub, Mb, 1, method=method, stopThr=1e-10) lossb = ot.sinkhorn2(ub, ub, Mb, 1, method=method, stopThr=1e-10) nx.assert_same_dtype_device(Mb, Gb) @@ -370,9 +363,7 @@ def test_sinkhorn_variants_multi_b(nx): M = ot.dist(x, x) - ub = nx.from_numpy(u) - bb = nx.from_numpy(b) - M_nx = nx.from_numpy(M) + ub, bb, M_nx = nx.from_numpy(u, b, M) G = ot.sinkhorn(u, b, M, 1, method='sinkhorn', stopThr=1e-10) Gl = nx.to_numpy(ot.sinkhorn(ub, bb, M_nx, 1, method='sinkhorn_log', stopThr=1e-10)) @@ -400,9 +391,7 @@ def test_sinkhorn2_variants_multi_b(nx): M = ot.dist(x, x) - ub = nx.from_numpy(u) - bb = nx.from_numpy(b) - M_nx = nx.from_numpy(M) + ub, bb, M_nx = nx.from_numpy(u, b, M) G = ot.sinkhorn2(u, b, M, 1, method='sinkhorn', stopThr=1e-10) Gl = nx.to_numpy(ot.sinkhorn2(ub, bb, M_nx, 1, method='sinkhorn_log', stopThr=1e-10)) @@ -483,9 +472,7 @@ def test_barycenter(nx, method, verbose, warn): alpha = 0.5 # 0<=alpha<=1 weights = np.array([1 - alpha, alpha]) - A_nx = nx.from_numpy(A) - M_nx = nx.from_numpy(M) - weights_nx = nx.from_numpy(weights) + A_nx, M_nx, weights_nx = nx.from_numpy(A, M, weights) reg = 1e-2 if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log": @@ -523,9 +510,7 @@ def test_barycenter_debiased(nx, method, verbose, warn): alpha = 0.5 # 0<=alpha<=1 weights = np.array([1 - alpha, alpha]) - A_nx = nx.from_numpy(A) - M_nx = nx.from_numpy(M) - weights_nx = nx.from_numpy(weights) + A_nx, M_nx, weights_nx = nx.from_numpy(A, M, weights) # wasserstein reg = 1e-2 @@ -594,9 +579,7 @@ def test_barycenter_stabilization(nx): alpha = 0.5 # 0<=alpha<=1 weights = np.array([1 - alpha, alpha]) - A_nx = nx.from_numpy(A) - M_nx = nx.from_numpy(M) - weights_b = nx.from_numpy(weights) + A_nx, M_nx, weights_b = nx.from_numpy(A, M, weights) # wasserstein reg = 1e-2 @@ -697,11 +680,7 @@ def test_unmix(nx): M0 /= M0.max() h0 = ot.unif(2) - ab = nx.from_numpy(a) - Db = nx.from_numpy(D) - M_nx = nx.from_numpy(M) - M0b = nx.from_numpy(M0) - h0b = nx.from_numpy(h0) + ab, Db, M_nx, M0b, h0b = nx.from_numpy(a, D, M, M0, h0) # wasserstein reg = 1e-3 @@ -727,12 +706,7 @@ def test_empirical_sinkhorn(nx): M = ot.dist(X_s, X_t) M_m = ot.dist(X_s, X_t, metric='euclidean') - ab = nx.from_numpy(a) - bb = nx.from_numpy(b) - X_sb = nx.from_numpy(X_s) - X_tb = nx.from_numpy(X_t) - M_nx = nx.from_numpy(M, type_as=ab) - M_mb = nx.from_numpy(M_m, type_as=ab) + ab, bb, X_sb, X_tb, M_nx, M_mb = nx.from_numpy(a, b, X_s, X_t, M, M_m) G_sqe = nx.to_numpy(ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1)) sinkhorn_sqe = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1)) @@ -776,12 +750,7 @@ def test_lazy_empirical_sinkhorn(nx): M = ot.dist(X_s, X_t) M_m = ot.dist(X_s, X_t, metric='euclidean') - ab = nx.from_numpy(a) - bb = nx.from_numpy(b) - X_sb = nx.from_numpy(X_s) - X_tb = nx.from_numpy(X_t) - M_nx = nx.from_numpy(M, type_as=ab) - M_mb = nx.from_numpy(M_m, type_as=ab) + ab, bb, X_sb, X_tb, M_nx, M_mb = nx.from_numpy(a, b, X_s, X_t, M, M_m) f, g = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=(1, 3), verbose=True) f, g = nx.to_numpy(f), nx.to_numpy(g) @@ -825,19 +794,13 @@ def test_empirical_sinkhorn_divergence(nx): a = np.linspace(1, n, n) a /= a.sum() b = ot.unif(n) - X_s = np.reshape(np.arange(n), (n, 1)) - X_t = np.reshape(np.arange(0, n * 2, 2), (n, 1)) + X_s = np.reshape(np.arange(n, dtype=np.float64), (n, 1)) + X_t = np.reshape(np.arange(0, n * 2, 2, dtype=np.float64), (n, 1)) M = ot.dist(X_s, X_t) M_s = ot.dist(X_s, X_s) M_t = ot.dist(X_t, X_t) - ab = nx.from_numpy(a) - bb = nx.from_numpy(b) - X_sb = nx.from_numpy(X_s) - X_tb = nx.from_numpy(X_t) - M_nx = nx.from_numpy(M, type_as=ab) - M_sb = nx.from_numpy(M_s, type_as=ab) - M_tb = nx.from_numpy(M_t, type_as=ab) + ab, bb, X_sb, X_tb, M_nx, M_sb, M_tb = nx.from_numpy(a, b, X_s, X_t, M, M_s, M_t) emp_sinkhorn_div = nx.to_numpy(ot.bregman.empirical_sinkhorn_divergence(X_sb, X_tb, 1, a=ab, b=bb)) sinkhorn_div = nx.to_numpy( @@ -872,9 +835,7 @@ def test_stabilized_vs_sinkhorn_multidim(nx): M /= np.median(M) epsilon = 0.1 - ab = nx.from_numpy(a) - bb = nx.from_numpy(b) - M_nx = nx.from_numpy(M, type_as=ab) + ab, bb, M_nx = nx.from_numpy(a, b, M) G_np, _ = ot.bregman.sinkhorn(a, b, M, reg=epsilon, method="sinkhorn", log=True) G, log = ot.bregman.sinkhorn(ab, bb, M_nx, reg=epsilon, @@ -936,9 +897,7 @@ def test_screenkhorn(nx): x = rng.randn(n, 2) M = ot.dist(x, x) - ab = nx.from_numpy(a) - bb = nx.from_numpy(b) - M_nx = nx.from_numpy(M, type_as=ab) + ab, bb, M_nx = nx.from_numpy(a, b, M) # sinkhorn G_sink = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1e-1)) |