diff options
author | Kilian Fatras <kilianfatras@dhcp-206-12-53-20.eduroam.wireless.ubc.ca> | 2018-08-28 18:51:28 -0700 |
---|---|---|
committer | Kilian Fatras <kilianfatras@dhcp-206-12-53-20.eduroam.wireless.ubc.ca> | 2018-08-28 18:51:28 -0700 |
commit | 37e3b29595223399ebe4710ac2bb061004814118 (patch) | |
tree | c76a57b9a2412af918f2be87137c93e154264ac7 | |
parent | cd193f78d392143ea9421da0f7e55ca8b75b8a0d (diff) |
fixed argument functions
-rw-r--r-- | ot/stochastic.py | 8 | ||||
-rw-r--r-- | test/test_stochastic.py | 4 |
2 files changed, 6 insertions, 6 deletions
diff --git a/ot/stochastic.py b/ot/stochastic.py index 0788f61..e33f6a0 100644 --- a/ot/stochastic.py +++ b/ot/stochastic.py @@ -435,7 +435,7 @@ def solve_semi_dual_entropic(a, b, M, reg, method, numItermax=10000, lr=None, ############################################################################## -def batch_grad_dual(M, reg, a, b, alpha, beta, batch_size, batch_alpha, +def batch_grad_dual(a, b, M, reg, alpha, beta, batch_size, batch_alpha, batch_beta): ''' Computes the partial gradient of F_\W_varepsilon @@ -528,7 +528,7 @@ def batch_grad_dual(M, reg, a, b, alpha, beta, batch_size, batch_alpha, return grad_alpha, grad_beta -def sgd_entropic_regularization(M, reg, a, b, batch_size, numItermax, lr): +def sgd_entropic_regularization(a, b, M, reg, batch_size, numItermax, lr): ''' Compute the sgd algorithm to solve the regularized discrete measures optimal transport dual problem @@ -612,7 +612,7 @@ def sgd_entropic_regularization(M, reg, a, b, batch_size, numItermax, lr): k = np.sqrt(cur_iter / 100 + 1) batch_alpha = np.random.choice(n_source, batch_size, replace=False) batch_beta = np.random.choice(n_target, batch_size, replace=False) - update_alpha, update_beta = batch_grad_dual(M, reg, a, b, cur_alpha, + update_alpha, update_beta = batch_grad_dual(a, b, M, reg, cur_alpha, cur_beta, batch_size, batch_alpha, batch_beta) cur_alpha += (lr / k) * update_alpha @@ -698,7 +698,7 @@ def solve_dual_entropic(a, b, M, reg, batch_size, numItermax=10000, lr=1, arXiv preprint arxiv:1711.02283. ''' - opt_alpha, opt_beta = sgd_entropic_regularization(M, reg, a, b, batch_size, + opt_alpha, opt_beta = sgd_entropic_regularization(a, b, M, reg, batch_size, numItermax, lr) pi = (np.exp((opt_alpha[:, None] + opt_beta[None, :] - M[:, :]) / reg) * a[:, None] * b[None, :]) diff --git a/test/test_stochastic.py b/test/test_stochastic.py index 4bbe230..f1d4825 100644 --- a/test/test_stochastic.py +++ b/test/test_stochastic.py @@ -196,8 +196,8 @@ def test_dual_sgd_sinkhorn(): reg = 1 batch_size = 30 - a = ot.datasets.get_1D_gauss(n, m=15, s=5) # m= mean, s= std - b = ot.datasets.get_1D_gauss(n, m=15, s=5) + a = ot.datasets.get_1D_gauss(n, 15, 5) # m= mean, s= std + b = ot.datasets.get_1D_gauss(n, 15, 5) X_source = np.arange(n, dtype=np.float64) Y_target = np.arange(n, dtype=np.float64) M = ot.dist(X_source.reshape((n, 1)), Y_target.reshape((n, 1))) |