From a08375c8dc7594e247e586fcc4d65a96771d25c7 Mon Sep 17 00:00:00 2001 From: Romain Tavenard Date: Mon, 1 Jul 2019 15:36:55 +0200 Subject: Fixed all doctests assuming functions are working properly (actually tested in tests/) --- ot/stochastic.py | 155 +++++++++++++++++++++++++++++++++++-------------------- 1 file changed, 99 insertions(+), 56 deletions(-) (limited to 'ot/stochastic.py') diff --git a/ot/stochastic.py b/ot/stochastic.py index 762eb3e..bf3e7a7 100644 --- a/ot/stochastic.py +++ b/ot/stochastic.py @@ -52,19 +52,23 @@ def coordinate_grad_semi_dual(b, M, reg, beta, i): Examples -------- >>> import ot + >>> np.random.seed(0) >>> n_source = 7 >>> n_target = 4 - >>> reg = 1 - >>> numItermax = 300000 >>> a = ot.utils.unif(n_source) >>> b = ot.utils.unif(n_target) - >>> rng = np.random.RandomState(0) - >>> X_source = rng.randn(n_source, 2) - >>> Y_target = rng.randn(n_target, 2) + >>> X_source = np.random.randn(n_source, 2) + >>> Y_target = np.random.randn(n_target, 2) >>> M = ot.dist(X_source, Y_target) - >>> method = "ASGD" - >>> asgd_pi = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, method, numItermax) - >>> print(asgd_pi) + >>> ot.stochastic.solve_semi_dual_entropic(a, b, M, reg=1, method="ASGD", numItermax=300000) + array([[2.53942342e-02, 9.98640673e-02, 1.75945647e-02, 4.27664307e-06], + [1.21556999e-01, 1.26350515e-02, 1.30491795e-03, 7.36017394e-03], + [3.54070702e-03, 7.63581358e-02, 6.29581672e-02, 1.32812798e-07], + [2.60578198e-02, 3.35916645e-02, 8.28023223e-02, 4.05336238e-04], + [9.86808864e-03, 7.59774324e-04, 1.08702729e-02, 1.21359007e-01], + [2.17218856e-02, 9.12931802e-04, 1.87962526e-03, 1.18342700e-01], + [4.14237512e-02, 2.67487857e-02, 7.23016955e-02, 2.38291052e-03]]) + References ---------- @@ -133,19 +137,22 @@ def sag_entropic_transport(a, b, M, reg, numItermax=10000, lr=None): Examples -------- >>> import ot + >>> np.random.seed(0) >>> n_source = 7 >>> n_target = 4 - >>> reg = 1 - >>> numItermax = 300000 >>> a = ot.utils.unif(n_source) >>> b = ot.utils.unif(n_target) - >>> rng = np.random.RandomState(0) - >>> X_source = rng.randn(n_source, 2) - >>> Y_target = rng.randn(n_target, 2) + >>> X_source = np.random.randn(n_source, 2) + >>> Y_target = np.random.randn(n_target, 2) >>> M = ot.dist(X_source, Y_target) - >>> method = "ASGD" - >>> asgd_pi = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, method, numItermax) - >>> print(asgd_pi) + >>> ot.stochastic.solve_semi_dual_entropic(a, b, M, reg=1, method="ASGD", numItermax=300000) + array([[2.53942342e-02, 9.98640673e-02, 1.75945647e-02, 4.27664307e-06], + [1.21556999e-01, 1.26350515e-02, 1.30491795e-03, 7.36017394e-03], + [3.54070702e-03, 7.63581358e-02, 6.29581672e-02, 1.32812798e-07], + [2.60578198e-02, 3.35916645e-02, 8.28023223e-02, 4.05336238e-04], + [9.86808864e-03, 7.59774324e-04, 1.08702729e-02, 1.21359007e-01], + [2.17218856e-02, 9.12931802e-04, 1.87962526e-03, 1.18342700e-01], + [4.14237512e-02, 2.67487857e-02, 7.23016955e-02, 2.38291052e-03]]) References ---------- @@ -222,19 +229,22 @@ def averaged_sgd_entropic_transport(a, b, M, reg, numItermax=300000, lr=None): Examples -------- >>> import ot + >>> np.random.seed(0) >>> n_source = 7 >>> n_target = 4 - >>> reg = 1 - >>> numItermax = 300000 >>> a = ot.utils.unif(n_source) >>> b = ot.utils.unif(n_target) - >>> rng = np.random.RandomState(0) - >>> X_source = rng.randn(n_source, 2) - >>> Y_target = rng.randn(n_target, 2) + >>> X_source = np.random.randn(n_source, 2) + >>> Y_target = np.random.randn(n_target, 2) >>> M = ot.dist(X_source, Y_target) - >>> method = "ASGD" - >>> asgd_pi = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, method, numItermax) - >>> print(asgd_pi) + >>> ot.stochastic.solve_semi_dual_entropic(a, b, M, reg=1, method="ASGD", numItermax=300000) + array([[2.53942342e-02, 9.98640673e-02, 1.75945647e-02, 4.27664307e-06], + [1.21556999e-01, 1.26350515e-02, 1.30491795e-03, 7.36017394e-03], + [3.54070702e-03, 7.63581358e-02, 6.29581672e-02, 1.32812798e-07], + [2.60578198e-02, 3.35916645e-02, 8.28023223e-02, 4.05336238e-04], + [9.86808864e-03, 7.59774324e-04, 1.08702729e-02, 1.21359007e-01], + [2.17218856e-02, 9.12931802e-04, 1.87962526e-03, 1.18342700e-01], + [4.14237512e-02, 2.67487857e-02, 7.23016955e-02, 2.38291052e-03]]) References ---------- @@ -301,19 +311,22 @@ def c_transform_entropic(b, M, reg, beta): Examples -------- >>> import ot + >>> np.random.seed(0) >>> n_source = 7 >>> n_target = 4 - >>> reg = 1 - >>> numItermax = 300000 >>> a = ot.utils.unif(n_source) >>> b = ot.utils.unif(n_target) - >>> rng = np.random.RandomState(0) - >>> X_source = rng.randn(n_source, 2) - >>> Y_target = rng.randn(n_target, 2) + >>> X_source = np.random.randn(n_source, 2) + >>> Y_target = np.random.randn(n_target, 2) >>> M = ot.dist(X_source, Y_target) - >>> method = "ASGD" - >>> asgd_pi = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, method, numItermax) - >>> print(asgd_pi) + >>> ot.stochastic.solve_semi_dual_entropic(a, b, M, reg=1, method="ASGD", numItermax=300000) + array([[2.53942342e-02, 9.98640673e-02, 1.75945647e-02, 4.27664307e-06], + [1.21556999e-01, 1.26350515e-02, 1.30491795e-03, 7.36017394e-03], + [3.54070702e-03, 7.63581358e-02, 6.29581672e-02, 1.32812798e-07], + [2.60578198e-02, 3.35916645e-02, 8.28023223e-02, 4.05336238e-04], + [9.86808864e-03, 7.59774324e-04, 1.08702729e-02, 1.21359007e-01], + [2.17218856e-02, 9.12931802e-04, 1.87962526e-03, 1.18342700e-01], + [4.14237512e-02, 2.67487857e-02, 7.23016955e-02, 2.38291052e-03]]) References ---------- @@ -395,19 +408,22 @@ def solve_semi_dual_entropic(a, b, M, reg, method, numItermax=10000, lr=None, Examples -------- >>> import ot + >>> np.random.seed(0) >>> n_source = 7 >>> n_target = 4 - >>> reg = 1 - >>> numItermax = 300000 >>> a = ot.utils.unif(n_source) >>> b = ot.utils.unif(n_target) - >>> rng = np.random.RandomState(0) - >>> X_source = rng.randn(n_source, 2) - >>> Y_target = rng.randn(n_target, 2) + >>> X_source = np.random.randn(n_source, 2) + >>> Y_target = np.random.randn(n_target, 2) >>> M = ot.dist(X_source, Y_target) - >>> method = "ASGD" - >>> asgd_pi = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, method, numItermax) - >>> print(asgd_pi) + >>> ot.stochastic.solve_semi_dual_entropic(a, b, M, reg=1, method="ASGD", numItermax=300000) + array([[2.53942342e-02, 9.98640673e-02, 1.75945647e-02, 4.27664307e-06], + [1.21556999e-01, 1.26350515e-02, 1.30491795e-03, 7.36017394e-03], + [3.54070702e-03, 7.63581358e-02, 6.29581672e-02, 1.32812798e-07], + [2.60578198e-02, 3.35916645e-02, 8.28023223e-02, 4.05336238e-04], + [9.86808864e-03, 7.59774324e-04, 1.08702729e-02, 1.21359007e-01], + [2.17218856e-02, 9.12931802e-04, 1.87962526e-03, 1.18342700e-01], + [4.14237512e-02, 2.67487857e-02, 7.23016955e-02, 2.38291052e-03]]) References ---------- @@ -502,22 +518,28 @@ def batch_grad_dual(a, b, M, reg, alpha, beta, batch_size, batch_alpha, Examples -------- >>> import ot + >>> np.random.seed(0) >>> n_source = 7 >>> n_target = 4 - >>> reg = 1 - >>> numItermax = 20000 - >>> lr = 0.1 - >>> batch_size = 3 - >>> log = True >>> a = ot.utils.unif(n_source) >>> b = ot.utils.unif(n_target) - >>> rng = np.random.RandomState(0) - >>> X_source = rng.randn(n_source, 2) - >>> Y_target = rng.randn(n_target, 2) + >>> X_source = np.random.randn(n_source, 2) + >>> Y_target = np.random.randn(n_target, 2) >>> M = ot.dist(X_source, Y_target) - >>> sgd_dual_pi, log = ot.stochastic.solve_dual_entropic(a, b, M, reg, batch_size, numItermax, lr, log) - >>> print(log['alpha'], log['beta']) - >>> print(sgd_dual_pi) + >>> sgd_dual_pi, log = ot.stochastic.solve_dual_entropic(a, b, M, reg=1, batch_size=3, numItermax=30000, lr=0.1, log=True) + >>> log['alpha'] + array([0.71759102, 1.57057384, 0.85576566, 0.1208211 , 0.59190466, + 1.197148 , 0.17805133]) + >>> log['beta'] + array([0.49741367, 0.57478564, 1.40075528, 2.75890102]) + >>> sgd_dual_pi + array([[2.09730063e-02, 8.38169324e-02, 7.50365455e-03, 8.72731415e-09], + [5.58432437e-03, 5.89881299e-04, 3.09558411e-05, 8.35469849e-07], + [3.26489515e-03, 7.15536035e-02, 2.99778211e-02, 3.02601593e-10], + [4.05390622e-02, 5.31085068e-02, 6.65191787e-02, 1.55812785e-06], + [7.82299812e-02, 6.12099102e-03, 4.44989098e-02, 2.37719187e-03], + [5.06266486e-02, 2.16230494e-03, 2.26215141e-03, 6.81514609e-04], + [6.06713990e-02, 3.98139808e-02, 5.46829338e-02, 8.62371424e-06]]) References ---------- @@ -526,7 +548,6 @@ def batch_grad_dual(a, b, M, reg, alpha, beta, batch_size, batch_alpha, International Conference on Learning Representation (2018), arXiv preprint arxiv:1711.02283. ''' - G = - (np.exp((alpha[batch_alpha, None] + beta[None, batch_beta] - M[batch_alpha, :][:, batch_beta]) / reg) * a[batch_alpha, None] * b[None, batch_beta]) @@ -605,8 +626,19 @@ def sgd_entropic_regularization(a, b, M, reg, batch_size, numItermax, lr): >>> Y_target = rng.randn(n_target, 2) >>> M = ot.dist(X_source, Y_target) >>> sgd_dual_pi, log = ot.stochastic.solve_dual_entropic(a, b, M, reg, batch_size, numItermax, lr, log) - >>> print(log['alpha'], log['beta']) - >>> print(sgd_dual_pi) + >>> log['alpha'] + array([0.64171798, 1.27932201, 0.78132257, 0.15638935, 0.54888354, + 1.03663469, 0.20595781]) + >>> log['beta'] + array([0.51207194, 0.58033189, 1.28922676, 2.26859736]) + >>> sgd_dual_pi + array([[1.97276541e-02, 7.81248547e-02, 6.22136048e-03, 4.95442423e-09], + [4.23494310e-03, 4.43286263e-04, 2.06927079e-05, 3.82389139e-07], + [3.07542414e-03, 6.67897769e-02, 2.48904999e-02, 1.72030247e-10], + [4.26271990e-02, 5.53375455e-02, 6.16535024e-02, 9.88812650e-07], + [7.60423265e-02, 5.89585256e-03, 3.81267087e-02, 1.39458256e-03], + [4.37557504e-02, 1.85189176e-03, 1.72335760e-03, 3.55491279e-04], + [6.33096109e-02, 4.11683954e-02, 5.02962051e-02, 5.43097516e-06]]) References ---------- @@ -701,8 +733,19 @@ def solve_dual_entropic(a, b, M, reg, batch_size, numItermax=10000, lr=1, >>> Y_target = rng.randn(n_target, 2) >>> M = ot.dist(X_source, Y_target) >>> sgd_dual_pi, log = ot.stochastic.solve_dual_entropic(a, b, M, reg, batch_size, numItermax, lr, log) - >>> print(log['alpha'], log['beta']) - >>> print(sgd_dual_pi) + >>> log['alpha'] + array([0.64057733, 1.2683513 , 0.75610161, 0.16024284, 0.54926534, + 1.0514201 , 0.19958936]) + >>> log['beta'] + array([0.51372571, 0.58843489, 1.27993921, 2.24344807]) + >>> sgd_dual_pi + array([[1.97377795e-02, 7.86706853e-02, 6.15682001e-03, 4.82586997e-09], + [4.19566963e-03, 4.42016865e-04, 2.02777272e-05, 3.68823708e-07], + [3.00379244e-03, 6.56562018e-02, 2.40462171e-02, 1.63579656e-10], + [4.28626062e-02, 5.60031599e-02, 6.13193826e-02, 9.67977735e-07], + [7.61972739e-02, 5.94609051e-03, 3.77886693e-02, 1.36046648e-03], + [4.44810042e-02, 1.89476742e-03, 1.73285847e-03, 3.51826036e-04], + [6.30118293e-02, 4.12398660e-02, 4.95148998e-02, 5.26247246e-06]]) References ---------- -- cgit v1.2.3