diff options
-rw-r--r-- | test/test_stochastic.py | 13 |
1 files changed, 5 insertions, 8 deletions
diff --git a/test/test_stochastic.py b/test/test_stochastic.py index 88ad666..4bbe230 100644 --- a/test/test_stochastic.py +++ b/test/test_stochastic.py @@ -193,17 +193,14 @@ def test_dual_sgd_sinkhorn(): # Test gaussian n = 30 - n_source = n - n_target = n reg = 1 - numItermax = 150000 batch_size = 30 - a = ot.datasets.get_1D_gauss(n_source, m=15, s=5) # m= mean, s= std - b = ot.datasets.get_1D_gauss(n_target, m=15, s=5) - X_source = np.arange(n_source,dtype=np.float64) - Y_target = np.arange(n_target,dtype=np.float64) - M = ot.dist(X_source.reshape((n_source, 1)), Y_target.reshape((n_target, 1))) + a = ot.datasets.get_1D_gauss(n, m=15, s=5) # m= mean, s= std + b = ot.datasets.get_1D_gauss(n, m=15, s=5) + X_source = np.arange(n, dtype=np.float64) + Y_target = np.arange(n, dtype=np.float64) + M = ot.dist(X_source.reshape((n, 1)), Y_target.reshape((n, 1))) M /= M.max() G_sgd = ot.stochastic.solve_dual_entropic(a, b, M, reg, batch_size, |