summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ot/stochastic.py8
-rw-r--r--test/test_stochastic.py4
2 files changed, 6 insertions, 6 deletions
diff --git a/ot/stochastic.py b/ot/stochastic.py
index 0788f61..e33f6a0 100644
--- a/ot/stochastic.py
+++ b/ot/stochastic.py
@@ -435,7 +435,7 @@ def solve_semi_dual_entropic(a, b, M, reg, method, numItermax=10000, lr=None,
##############################################################################
-def batch_grad_dual(M, reg, a, b, alpha, beta, batch_size, batch_alpha,
+def batch_grad_dual(a, b, M, reg, alpha, beta, batch_size, batch_alpha,
batch_beta):
'''
Computes the partial gradient of F_\W_varepsilon
@@ -528,7 +528,7 @@ def batch_grad_dual(M, reg, a, b, alpha, beta, batch_size, batch_alpha,
return grad_alpha, grad_beta
-def sgd_entropic_regularization(M, reg, a, b, batch_size, numItermax, lr):
+def sgd_entropic_regularization(a, b, M, reg, batch_size, numItermax, lr):
'''
Compute the sgd algorithm to solve the regularized discrete measures
optimal transport dual problem
@@ -612,7 +612,7 @@ def sgd_entropic_regularization(M, reg, a, b, batch_size, numItermax, lr):
k = np.sqrt(cur_iter / 100 + 1)
batch_alpha = np.random.choice(n_source, batch_size, replace=False)
batch_beta = np.random.choice(n_target, batch_size, replace=False)
- update_alpha, update_beta = batch_grad_dual(M, reg, a, b, cur_alpha,
+ update_alpha, update_beta = batch_grad_dual(a, b, M, reg, cur_alpha,
cur_beta, batch_size,
batch_alpha, batch_beta)
cur_alpha += (lr / k) * update_alpha
@@ -698,7 +698,7 @@ def solve_dual_entropic(a, b, M, reg, batch_size, numItermax=10000, lr=1,
arXiv preprint arxiv:1711.02283.
'''
- opt_alpha, opt_beta = sgd_entropic_regularization(M, reg, a, b, batch_size,
+ opt_alpha, opt_beta = sgd_entropic_regularization(a, b, M, reg, batch_size,
numItermax, lr)
pi = (np.exp((opt_alpha[:, None] + opt_beta[None, :] - M[:, :]) / reg) *
a[:, None] * b[None, :])
diff --git a/test/test_stochastic.py b/test/test_stochastic.py
index 4bbe230..f1d4825 100644
--- a/test/test_stochastic.py
+++ b/test/test_stochastic.py
@@ -196,8 +196,8 @@ def test_dual_sgd_sinkhorn():
reg = 1
batch_size = 30
- a = ot.datasets.get_1D_gauss(n, m=15, s=5) # m= mean, s= std
- b = ot.datasets.get_1D_gauss(n, m=15, s=5)
+ a = ot.datasets.get_1D_gauss(n, 15, 5) # m= mean, s= std
+ b = ot.datasets.get_1D_gauss(n, 15, 5)
X_source = np.arange(n, dtype=np.float64)
Y_target = np.arange(n, dtype=np.float64)
M = ot.dist(X_source.reshape((n, 1)), Y_target.reshape((n, 1)))