From e068b58ba4234792d96287afd34c3cddef544dd4 Mon Sep 17 00:00:00 2001 From: Kilian Fatras Date: Mon, 18 Jun 2018 18:05:48 -0700 Subject: pep8 --- examples/plot_stochastic.py | 6 +++--- ot/stochastic.py | 33 +++++++++++++++------------------ test/test_stochastic.py | 4 ++-- 3 files changed, 20 insertions(+), 23 deletions(-) diff --git a/examples/plot_stochastic.py b/examples/plot_stochastic.py index 3fc1955..e139473 100644 --- a/examples/plot_stochastic.py +++ b/examples/plot_stochastic.py @@ -53,7 +53,7 @@ M = ot.dist(X_source, Y_target) method = "SAG" sag_pi = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, method, - numItermax, lr) + numItermax, lr) print(sag_pi) ############################################################################# @@ -91,7 +91,7 @@ M = ot.dist(X_source, Y_target) method = "ASGD" asgd_pi, log = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, method, - numItermax, lr, log) + numItermax, lr, log) print(log['alpha'], log['beta']) print(asgd_pi) @@ -174,7 +174,7 @@ M = ot.dist(X_source, Y_target) # Call ot.solve_dual_entropic and plot the results. sgd_dual_pi, log = ot.stochastic.solve_dual_entropic(a, b, M, reg, batch_size, - numItermax, lr, log) + numItermax, lr, log) print(log['alpha'], log['beta']) print(sgd_dual_pi) diff --git a/ot/stochastic.py b/ot/stochastic.py index 31c99be..98537d9 100644 --- a/ot/stochastic.py +++ b/ot/stochastic.py @@ -78,8 +78,8 @@ def coordinate_gradient(b, M, reg, beta, i): ''' r = M[i, :] - beta - exp_beta = np.exp(-r/reg) * b - khi = exp_beta/(np.sum(exp_beta)) + exp_beta = np.exp(-r / reg) * b + khi = exp_beta / (np.sum(exp_beta)) return b - khi @@ -164,7 +164,7 @@ def sag_entropic_transport(a, b, M, reg, numItermax=10000, lr=0.1): cur_coord_grad = a[i] * coordinate_gradient(b, M, reg, cur_beta, i) sum_stored_gradient += (cur_coord_grad - stored_gradient[i]) stored_gradient[i] = cur_coord_grad - cur_beta += lr * (1./n_source) * sum_stored_gradient + cur_beta += lr * (1. / n_source) * sum_stored_gradient return cur_beta @@ -246,8 +246,8 @@ def averaged_sgd_entropic_transport(b, M, reg, numItermax=300000, lr=1): k = cur_iter + 1 i = np.random.randint(n_source) cur_coord_grad = coordinate_gradient(b, M, reg, cur_beta, i) - cur_beta += (lr/np.sqrt(k)) * cur_coord_grad - ave_beta = (1./k) * cur_beta + (1 - 1./k) * ave_beta + cur_beta += (lr / np.sqrt(k)) * cur_coord_grad + ave_beta = (1. / k) * cur_beta + (1 - 1. / k) * ave_beta return ave_beta @@ -316,12 +316,11 @@ def c_transform_entropic(b, M, reg, beta): ''' n_source = np.shape(M)[0] - n_target = np.shape(M)[1] alpha = np.zeros(n_source) for i in range(n_source): r = M[i, :] - beta min_r = np.min(r) - exp_beta = np.exp(-(r - min_r)/reg) * b + exp_beta = np.exp(-(r - min_r) / reg) * b alpha[i] = min_r - reg * np.log(np.sum(exp_beta)) return alpha @@ -407,8 +406,6 @@ def solve_semi_dual_entropic(a, b, M, reg, method, numItermax=10000, lr=0.1, Advances in Neural Information Processing Systems (2016), arXiv preprint arxiv:1605.08527. ''' - n_source = 7 - n_target = 4 if method.lower() == "sag": opt_beta = sag_entropic_transport(a, b, M, reg, numItermax, lr) elif method.lower() == "asgd": @@ -418,7 +415,7 @@ def solve_semi_dual_entropic(a, b, M, reg, method, numItermax=10000, lr=0.1, return None opt_alpha = c_transform_entropic(b, M, reg, opt_beta) - pi = (np.exp((opt_alpha[:, None] + opt_beta[None, :] - M[:, :])/reg) * + pi = (np.exp((opt_alpha[:, None] + opt_beta[None, :] - M[:, :]) / reg) * a[:, None] * b[None, :]) if log: @@ -511,8 +508,8 @@ def grad_dF_dalpha(M, reg, alpha, beta, batch_size, batch_alpha, batch_beta): grad_alpha = np.zeros(batch_size) grad_alpha[:] = batch_size for j in batch_beta: - grad_alpha -= np.exp((alpha[batch_alpha] + beta[j] - - M[batch_alpha, j])/reg) + grad_alpha -= np.exp((alpha[batch_alpha] + beta[j] - + M[batch_alpha, j]) / reg) return grad_alpha @@ -594,7 +591,7 @@ def grad_dF_dbeta(M, reg, alpha, beta, batch_size, batch_alpha, batch_beta): grad_beta[:] = batch_size for i in batch_alpha: grad_beta -= np.exp((alpha[i] + - beta[batch_beta] - M[i, batch_beta])/reg) + beta[batch_beta] - M[i, batch_beta]) / reg) return grad_beta @@ -681,10 +678,10 @@ def sgd_entropic_regularization(M, reg, batch_size, numItermax, lr, batch_beta = np.random.choice(n_target, batch_size, replace=False) grad_F_alpha = grad_dF_dalpha(M, reg, cur_alpha, cur_beta, batch_size, batch_alpha, batch_beta) - cur_alpha[batch_alpha] += (lr/k) * grad_F_alpha + cur_alpha[batch_alpha] += (lr / k) * grad_F_alpha grad_F_beta = grad_dF_dbeta(M, reg, cur_alpha, cur_beta, batch_size, batch_alpha, batch_beta) - cur_beta[batch_beta] += (lr/k) * grad_F_beta + cur_beta[batch_beta] += (lr / k) * grad_F_beta else: for cur_iter in range(numItermax): @@ -695,8 +692,8 @@ def sgd_entropic_regularization(M, reg, batch_size, numItermax, lr, batch_size, batch_alpha, batch_beta) grad_F_beta = grad_dF_dbeta(M, reg, cur_alpha, cur_beta, batch_size, batch_alpha, batch_beta) - cur_alpha[batch_alpha] += (lr/k) * grad_F_alpha - cur_beta[batch_beta] += (lr/k) * grad_F_beta + cur_alpha[batch_alpha] += (lr / k) * grad_F_alpha + cur_beta[batch_beta] += (lr / k) * grad_F_beta return cur_alpha, cur_beta @@ -779,7 +776,7 @@ def solve_dual_entropic(a, b, M, reg, batch_size, numItermax=10000, lr=1, opt_alpha, opt_beta = sgd_entropic_regularization(M, reg, batch_size, numItermax, lr) - pi = (np.exp((opt_alpha[:, None] + opt_beta[None, :] - M[:, :])/reg) * + pi = (np.exp((opt_alpha[:, None] + opt_beta[None, :] - M[:, :]) / reg) * a[:, None] * b[None, :]) if log: log = {} diff --git a/test/test_stochastic.py b/test/test_stochastic.py index bc0cebb..5824df1 100644 --- a/test/test_stochastic.py +++ b/test/test_stochastic.py @@ -149,7 +149,7 @@ def test_stochastic_dual_sgd(): M = ot.dist(x, x) G = ot.stochastic.solve_dual_entropic(u, u, M, reg, batch_size, - numItermax=numItermax) + numItermax=numItermax) # check constratints np.testing.assert_allclose( @@ -180,7 +180,7 @@ def test_dual_sgd_sinkhorn(): M = ot.dist(x, x) G_sgd = ot.stochastic.solve_dual_entropic(u, u, M, reg, batch_size, - numItermax=nb_iter) + numItermax=nb_iter) G_sinkhorn = ot.sinkhorn(u, u, M, reg) -- cgit v1.2.3