summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
Diffstat (limited to 'test')
-rw-r--r--test/test_stochastic.py13
1 files changed, 5 insertions, 8 deletions
diff --git a/test/test_stochastic.py b/test/test_stochastic.py
index 88ad666..4bbe230 100644
--- a/test/test_stochastic.py
+++ b/test/test_stochastic.py
@@ -193,17 +193,14 @@ def test_dual_sgd_sinkhorn():
# Test gaussian
n = 30
- n_source = n
- n_target = n
reg = 1
- numItermax = 150000
batch_size = 30
- a = ot.datasets.get_1D_gauss(n_source, m=15, s=5) # m= mean, s= std
- b = ot.datasets.get_1D_gauss(n_target, m=15, s=5)
- X_source = np.arange(n_source,dtype=np.float64)
- Y_target = np.arange(n_target,dtype=np.float64)
- M = ot.dist(X_source.reshape((n_source, 1)), Y_target.reshape((n_target, 1)))
+ a = ot.datasets.get_1D_gauss(n, m=15, s=5) # m= mean, s= std
+ b = ot.datasets.get_1D_gauss(n, m=15, s=5)
+ X_source = np.arange(n, dtype=np.float64)
+ Y_target = np.arange(n, dtype=np.float64)
+ M = ot.dist(X_source.reshape((n, 1)), Y_target.reshape((n, 1)))
M /= M.max()
G_sgd = ot.stochastic.solve_dual_entropic(a, b, M, reg, batch_size,