summaryrefslogtreecommitdiff
path: root/test/test_bregman.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/test_bregman.py')
-rw-r--r--test/test_bregman.py81
1 files changed, 20 insertions, 61 deletions
diff --git a/test/test_bregman.py b/test/test_bregman.py
index 1419f9b..6c37984 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -155,8 +155,7 @@ def test_sinkhorn_backends(nx):
G = ot.sinkhorn(a, a, M, 1)
- ab = nx.from_numpy(a)
- M_nx = nx.from_numpy(M)
+ ab, M_nx = nx.from_numpy(a, M)
Gb = ot.sinkhorn(ab, ab, M_nx, 1)
@@ -176,8 +175,7 @@ def test_sinkhorn2_backends(nx):
G = ot.sinkhorn(a, a, M, 1)
- ab = nx.from_numpy(a)
- M_nx = nx.from_numpy(M)
+ ab, M_nx = nx.from_numpy(a, M)
Gb = ot.sinkhorn2(ab, ab, M_nx, 1)
@@ -260,8 +258,7 @@ def test_sinkhorn_variants(nx):
M = ot.dist(x, x)
- ub = nx.from_numpy(u)
- M_nx = nx.from_numpy(M)
+ ub, M_nx = nx.from_numpy(u, M)
G = ot.sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10)
Gl = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method='sinkhorn_log', stopThr=1e-10))
@@ -298,8 +295,7 @@ def test_sinkhorn_variants_dtype_device(nx, method):
for tp in nx.__type_list__:
print(nx.dtype_device(tp))
- ub = nx.from_numpy(u, type_as=tp)
- Mb = nx.from_numpy(M, type_as=tp)
+ ub, Mb = nx.from_numpy(u, M, type_as=tp)
Gb = ot.sinkhorn(ub, ub, Mb, 1, method=method, stopThr=1e-10)
@@ -318,8 +314,7 @@ def test_sinkhorn2_variants_dtype_device(nx, method):
for tp in nx.__type_list__:
print(nx.dtype_device(tp))
- ub = nx.from_numpy(u, type_as=tp)
- Mb = nx.from_numpy(M, type_as=tp)
+ ub, Mb = nx.from_numpy(u, M, type_as=tp)
lossb = ot.sinkhorn2(ub, ub, Mb, 1, method=method, stopThr=1e-10)
@@ -337,8 +332,7 @@ def test_sinkhorn2_variants_device_tf(method):
# Check that everything stays on the CPU
with tf.device("/CPU:0"):
- ub = nx.from_numpy(u)
- Mb = nx.from_numpy(M)
+ ub, Mb = nx.from_numpy(u, M)
Gb = ot.sinkhorn(ub, ub, Mb, 1, method=method, stopThr=1e-10)
lossb = ot.sinkhorn2(ub, ub, Mb, 1, method=method, stopThr=1e-10)
nx.assert_same_dtype_device(Mb, Gb)
@@ -346,8 +340,7 @@ def test_sinkhorn2_variants_device_tf(method):
if len(tf.config.list_physical_devices('GPU')) > 0:
# Check that everything happens on the GPU
- ub = nx.from_numpy(u)
- Mb = nx.from_numpy(M)
+ ub, Mb = nx.from_numpy(u, M)
Gb = ot.sinkhorn(ub, ub, Mb, 1, method=method, stopThr=1e-10)
lossb = ot.sinkhorn2(ub, ub, Mb, 1, method=method, stopThr=1e-10)
nx.assert_same_dtype_device(Mb, Gb)
@@ -370,9 +363,7 @@ def test_sinkhorn_variants_multi_b(nx):
M = ot.dist(x, x)
- ub = nx.from_numpy(u)
- bb = nx.from_numpy(b)
- M_nx = nx.from_numpy(M)
+ ub, bb, M_nx = nx.from_numpy(u, b, M)
G = ot.sinkhorn(u, b, M, 1, method='sinkhorn', stopThr=1e-10)
Gl = nx.to_numpy(ot.sinkhorn(ub, bb, M_nx, 1, method='sinkhorn_log', stopThr=1e-10))
@@ -400,9 +391,7 @@ def test_sinkhorn2_variants_multi_b(nx):
M = ot.dist(x, x)
- ub = nx.from_numpy(u)
- bb = nx.from_numpy(b)
- M_nx = nx.from_numpy(M)
+ ub, bb, M_nx = nx.from_numpy(u, b, M)
G = ot.sinkhorn2(u, b, M, 1, method='sinkhorn', stopThr=1e-10)
Gl = nx.to_numpy(ot.sinkhorn2(ub, bb, M_nx, 1, method='sinkhorn_log', stopThr=1e-10))
@@ -483,9 +472,7 @@ def test_barycenter(nx, method, verbose, warn):
alpha = 0.5 # 0<=alpha<=1
weights = np.array([1 - alpha, alpha])
- A_nx = nx.from_numpy(A)
- M_nx = nx.from_numpy(M)
- weights_nx = nx.from_numpy(weights)
+ A_nx, M_nx, weights_nx = nx.from_numpy(A, M, weights)
reg = 1e-2
if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log":
@@ -523,9 +510,7 @@ def test_barycenter_debiased(nx, method, verbose, warn):
alpha = 0.5 # 0<=alpha<=1
weights = np.array([1 - alpha, alpha])
- A_nx = nx.from_numpy(A)
- M_nx = nx.from_numpy(M)
- weights_nx = nx.from_numpy(weights)
+ A_nx, M_nx, weights_nx = nx.from_numpy(A, M, weights)
# wasserstein
reg = 1e-2
@@ -594,9 +579,7 @@ def test_barycenter_stabilization(nx):
alpha = 0.5 # 0<=alpha<=1
weights = np.array([1 - alpha, alpha])
- A_nx = nx.from_numpy(A)
- M_nx = nx.from_numpy(M)
- weights_b = nx.from_numpy(weights)
+ A_nx, M_nx, weights_b = nx.from_numpy(A, M, weights)
# wasserstein
reg = 1e-2
@@ -697,11 +680,7 @@ def test_unmix(nx):
M0 /= M0.max()
h0 = ot.unif(2)
- ab = nx.from_numpy(a)
- Db = nx.from_numpy(D)
- M_nx = nx.from_numpy(M)
- M0b = nx.from_numpy(M0)
- h0b = nx.from_numpy(h0)
+ ab, Db, M_nx, M0b, h0b = nx.from_numpy(a, D, M, M0, h0)
# wasserstein
reg = 1e-3
@@ -727,12 +706,7 @@ def test_empirical_sinkhorn(nx):
M = ot.dist(X_s, X_t)
M_m = ot.dist(X_s, X_t, metric='euclidean')
- ab = nx.from_numpy(a)
- bb = nx.from_numpy(b)
- X_sb = nx.from_numpy(X_s)
- X_tb = nx.from_numpy(X_t)
- M_nx = nx.from_numpy(M, type_as=ab)
- M_mb = nx.from_numpy(M_m, type_as=ab)
+ ab, bb, X_sb, X_tb, M_nx, M_mb = nx.from_numpy(a, b, X_s, X_t, M, M_m)
G_sqe = nx.to_numpy(ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1))
sinkhorn_sqe = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1))
@@ -776,12 +750,7 @@ def test_lazy_empirical_sinkhorn(nx):
M = ot.dist(X_s, X_t)
M_m = ot.dist(X_s, X_t, metric='euclidean')
- ab = nx.from_numpy(a)
- bb = nx.from_numpy(b)
- X_sb = nx.from_numpy(X_s)
- X_tb = nx.from_numpy(X_t)
- M_nx = nx.from_numpy(M, type_as=ab)
- M_mb = nx.from_numpy(M_m, type_as=ab)
+ ab, bb, X_sb, X_tb, M_nx, M_mb = nx.from_numpy(a, b, X_s, X_t, M, M_m)
f, g = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=(1, 3), verbose=True)
f, g = nx.to_numpy(f), nx.to_numpy(g)
@@ -825,19 +794,13 @@ def test_empirical_sinkhorn_divergence(nx):
a = np.linspace(1, n, n)
a /= a.sum()
b = ot.unif(n)
- X_s = np.reshape(np.arange(n), (n, 1))
- X_t = np.reshape(np.arange(0, n * 2, 2), (n, 1))
+ X_s = np.reshape(np.arange(n, dtype=np.float64), (n, 1))
+ X_t = np.reshape(np.arange(0, n * 2, 2, dtype=np.float64), (n, 1))
M = ot.dist(X_s, X_t)
M_s = ot.dist(X_s, X_s)
M_t = ot.dist(X_t, X_t)
- ab = nx.from_numpy(a)
- bb = nx.from_numpy(b)
- X_sb = nx.from_numpy(X_s)
- X_tb = nx.from_numpy(X_t)
- M_nx = nx.from_numpy(M, type_as=ab)
- M_sb = nx.from_numpy(M_s, type_as=ab)
- M_tb = nx.from_numpy(M_t, type_as=ab)
+ ab, bb, X_sb, X_tb, M_nx, M_sb, M_tb = nx.from_numpy(a, b, X_s, X_t, M, M_s, M_t)
emp_sinkhorn_div = nx.to_numpy(ot.bregman.empirical_sinkhorn_divergence(X_sb, X_tb, 1, a=ab, b=bb))
sinkhorn_div = nx.to_numpy(
@@ -872,9 +835,7 @@ def test_stabilized_vs_sinkhorn_multidim(nx):
M /= np.median(M)
epsilon = 0.1
- ab = nx.from_numpy(a)
- bb = nx.from_numpy(b)
- M_nx = nx.from_numpy(M, type_as=ab)
+ ab, bb, M_nx = nx.from_numpy(a, b, M)
G_np, _ = ot.bregman.sinkhorn(a, b, M, reg=epsilon, method="sinkhorn", log=True)
G, log = ot.bregman.sinkhorn(ab, bb, M_nx, reg=epsilon,
@@ -936,9 +897,7 @@ def test_screenkhorn(nx):
x = rng.randn(n, 2)
M = ot.dist(x, x)
- ab = nx.from_numpy(a)
- bb = nx.from_numpy(b)
- M_nx = nx.from_numpy(M, type_as=ab)
+ ab, bb, M_nx = nx.from_numpy(a, b, M)
# sinkhorn
G_sink = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1e-1))