summaryrefslogtreecommitdiff
path: root/examples
diff options
context:
space:
mode:
authorRĂ©mi Flamary <remi.flamary@gmail.com>2023-06-09 20:26:52 +0200
committerGitHub <noreply@github.com>2023-06-09 20:26:52 +0200
commit6c1e1f3e064165d37e22acc866c6fff56e3ab6ad (patch)
tree354aa9c554a9e7490c93bd2ac579675f1b933329 /examples
parent5faa4fbdb1a64351a42d31dd6f54f0402c29c405 (diff)
[MRG] Update tests and documentation (#484)
* remove old macos and windows tets update requirements * speedup ssw and continuaous ot exmaples * speedup regpath and variane * speedup conv 2d example + continuous stick * speedup regpath
Diffstat (limited to 'examples')
-rw-r--r--examples/backends/plot_sliced_wass_grad_flow_pytorch.py8
-rw-r--r--examples/backends/plot_ssw_unif_torch.py12
-rw-r--r--examples/backends/plot_stoch_continuous_ot_pytorch.py6
-rw-r--r--examples/barycenters/plot_convolutional_barycenter.py8
-rw-r--r--examples/plot_OT_1D_smooth.py2
-rw-r--r--examples/sliced-wasserstein/plot_variance.py8
-rw-r--r--examples/sliced-wasserstein/plot_variance_ssw.py8
-rw-r--r--examples/unbalanced-partial/plot_regpath.py14
8 files changed, 33 insertions, 33 deletions
diff --git a/examples/backends/plot_sliced_wass_grad_flow_pytorch.py b/examples/backends/plot_sliced_wass_grad_flow_pytorch.py
index 07a4926..7cbfd98 100644
--- a/examples/backends/plot_sliced_wass_grad_flow_pytorch.py
+++ b/examples/backends/plot_sliced_wass_grad_flow_pytorch.py
@@ -40,8 +40,8 @@ import torch
import ot
import matplotlib.animation as animation
-I1 = pl.imread('../../data/redcross.png').astype(np.float64)[::4, ::4, 2]
-I2 = pl.imread('../../data/tooth.png').astype(np.float64)[::4, ::4, 2]
+I1 = pl.imread('../../data/redcross.png').astype(np.float64)[::5, ::5, 2]
+I2 = pl.imread('../../data/tooth.png').astype(np.float64)[::5, ::5, 2]
sz = I2.shape[0]
XX, YY = np.meshgrid(np.arange(sz), np.arange(sz))
@@ -67,7 +67,7 @@ x2_torch = torch.tensor(x2).to(device=device)
lr = 1e3
-nb_iter_max = 100
+nb_iter_max = 50
x_all = np.zeros((nb_iter_max, x1.shape[0], 2))
@@ -129,7 +129,7 @@ xbinit = np.random.randn(500, 2) * 10 + 16
xbary_torch = torch.tensor(xbinit).to(device=device).requires_grad_(True)
lr = 1e3
-nb_iter_max = 100
+nb_iter_max = 50
x_all = np.zeros((nb_iter_max, xbary_torch.shape[0], 2))
diff --git a/examples/backends/plot_ssw_unif_torch.py b/examples/backends/plot_ssw_unif_torch.py
index afe3fa6..7459cf6 100644
--- a/examples/backends/plot_ssw_unif_torch.py
+++ b/examples/backends/plot_ssw_unif_torch.py
@@ -35,7 +35,7 @@ import ot
torch.manual_seed(1)
-N = 1000
+N = 500
x0 = torch.rand(N, 3)
x0 = F.normalize(x0, dim=-1)
@@ -72,8 +72,8 @@ ax.legend()
x = x0.clone()
x.requires_grad_(True)
-n_iter = 500
-lr = 100
+n_iter = 100
+lr = 150
losses = []
xvisu = torch.zeros(n_iter, N, 3)
@@ -82,7 +82,7 @@ for i in range(n_iter):
sw = ot.sliced_wasserstein_sphere_unif(x, n_projections=500)
grad_x = torch.autograd.grad(sw, x)[0]
- x = x - lr * grad_x
+ x = x - lr * grad_x / np.sqrt(i / 10 + 1)
x = F.normalize(x, p=2, dim=1)
losses.append(sw.item())
@@ -102,7 +102,7 @@ pl.xlabel("Iterations")
# Plot trajectories of generated samples along iterations
# -------------------------------------------------------
-ivisu = [0, 25, 50, 75, 100, 150, 200, 350, 499]
+ivisu = [0, 10, 20, 30, 40, 50, 60, 70, 80]
fig = pl.figure(3, (10, 10))
for i in range(9):
@@ -149,5 +149,5 @@ ax.set_ylim((-1.5, 1.5))
ax.set_title('Iter. {}'.format(ivisu[i]))
-ani = animation.FuncAnimation(pl.gcf(), _update_plot, n_iter // 5, interval=100, repeat_delay=2000)
+ani = animation.FuncAnimation(pl.gcf(), _update_plot, n_iter // 5, interval=200, repeat_delay=2000)
# %%
diff --git a/examples/backends/plot_stoch_continuous_ot_pytorch.py b/examples/backends/plot_stoch_continuous_ot_pytorch.py
index 714a5d3..e642986 100644
--- a/examples/backends/plot_stoch_continuous_ot_pytorch.py
+++ b/examples/backends/plot_stoch_continuous_ot_pytorch.py
@@ -27,8 +27,8 @@ import ot.plot
torch.manual_seed(42)
np.random.seed(42)
-n_source_samples = 10000
-n_target_samples = 10000
+n_source_samples = 1000
+n_target_samples = 1000
theta = 2 * np.pi / 20
noise_level = 0.1
@@ -89,7 +89,7 @@ reg = 1
optimizer = torch.optim.Adam(list(u.parameters()) + list(v.parameters()), lr=.005)
# number of iteration
-n_iter = 1000
+n_iter = 500
n_batch = 500
diff --git a/examples/barycenters/plot_convolutional_barycenter.py b/examples/barycenters/plot_convolutional_barycenter.py
index 3721f31..143b3a6 100644
--- a/examples/barycenters/plot_convolutional_barycenter.py
+++ b/examples/barycenters/plot_convolutional_barycenter.py
@@ -29,10 +29,10 @@ import ot
this_file = os.path.realpath('__file__')
data_path = os.path.join(Path(this_file).parent.parent.parent, 'data')
-f1 = 1 - plt.imread(os.path.join(data_path, 'redcross.png'))[:, :, 2]
-f2 = 1 - plt.imread(os.path.join(data_path, 'tooth.png'))[:, :, 2]
-f3 = 1 - plt.imread(os.path.join(data_path, 'heart.png'))[:, :, 2]
-f4 = 1 - plt.imread(os.path.join(data_path, 'duck.png'))[:, :, 2]
+f1 = 1 - plt.imread(os.path.join(data_path, 'redcross.png'))[::2, ::2, 2]
+f2 = 1 - plt.imread(os.path.join(data_path, 'tooth.png'))[::2, ::2, 2]
+f3 = 1 - plt.imread(os.path.join(data_path, 'heart.png'))[::2, ::2, 2]
+f4 = 1 - plt.imread(os.path.join(data_path, 'duck.png'))[::2, ::2, 2]
f1 = f1 / np.sum(f1)
f2 = f2 / np.sum(f2)
diff --git a/examples/plot_OT_1D_smooth.py b/examples/plot_OT_1D_smooth.py
index 626938c..4f233fe 100644
--- a/examples/plot_OT_1D_smooth.py
+++ b/examples/plot_OT_1D_smooth.py
@@ -14,7 +14,7 @@ sparsity-constrained OT, together with their visualizations.
#
# License: MIT License
-# sphinx_gallery_thumbnail_number = 6
+# sphinx_gallery_thumbnail_number = 5
import numpy as np
import matplotlib.pylab as pl
diff --git a/examples/sliced-wasserstein/plot_variance.py b/examples/sliced-wasserstein/plot_variance.py
index 2293247..77df2f5 100644
--- a/examples/sliced-wasserstein/plot_variance.py
+++ b/examples/sliced-wasserstein/plot_variance.py
@@ -29,7 +29,7 @@ import ot
# %% parameters and data generation
-n = 500 # nb samples
+n = 200 # nb samples
mu_s = np.array([0, 0])
cov_s = np.array([[1, 0], [0, 1]])
@@ -58,9 +58,9 @@ pl.title('Source and target distributions')
# Sliced Wasserstein distance for different seeds and number of projections
# -------------------------------------------------------------------------
-n_seed = 50
-n_projections_arr = np.logspace(0, 3, 25, dtype=int)
-res = np.empty((n_seed, 25))
+n_seed = 20
+n_projections_arr = np.logspace(0, 3, 10, dtype=int)
+res = np.empty((n_seed, 10))
# %% Compute statistics
for seed in range(n_seed):
diff --git a/examples/sliced-wasserstein/plot_variance_ssw.py b/examples/sliced-wasserstein/plot_variance_ssw.py
index f5fc35f..246b2a8 100644
--- a/examples/sliced-wasserstein/plot_variance_ssw.py
+++ b/examples/sliced-wasserstein/plot_variance_ssw.py
@@ -28,7 +28,7 @@ import ot
# %% parameters and data generation
-n = 500 # nb samples
+n = 200 # nb samples
xs = np.random.randn(n, 3)
xt = np.random.randn(n, 3)
@@ -81,9 +81,9 @@ pl.title("Source and Target distribution")
# Spherical Sliced Wasserstein for different seeds and number of projections
# --------------------------------------------------------------------------
-n_seed = 50
-n_projections_arr = np.logspace(0, 3, 25, dtype=int)
-res = np.empty((n_seed, 25))
+n_seed = 20
+n_projections_arr = np.logspace(0, 3, 10, dtype=int)
+res = np.empty((n_seed, 10))
# %% Compute statistics
for seed in range(n_seed):
diff --git a/examples/unbalanced-partial/plot_regpath.py b/examples/unbalanced-partial/plot_regpath.py
index d1f2042..ffedc6e 100644
--- a/examples/unbalanced-partial/plot_regpath.py
+++ b/examples/unbalanced-partial/plot_regpath.py
@@ -27,7 +27,7 @@ import matplotlib.animation as animation
#%% parameters and data generation
-n = 50 # nb samples
+n = 20 # nb samples
mu_s = np.array([-1, -1])
cov_s = np.array([[1, 0], [0, 1]])
@@ -63,7 +63,7 @@ pl.show()
# -----------------------------------------------------------
#%%
-final_gamma = 1e-8
+final_gamma = 1e-6
t, t_list, g_list = ot.regpath.regularization_path(a, b, M, reg=final_gamma,
semi_relaxed=False)
t2, t_list2, g_list2 = ot.regpath.regularization_path(a, b, M, reg=final_gamma,
@@ -111,7 +111,7 @@ pl.show()
# Animation of the regpath for UOT l2
# -----------------------------------
-nv = 100
+nv = 50
g_list_v = np.logspace(-.5, -2.5, nv)
pl.figure(3)
@@ -144,7 +144,7 @@ def _update_plot(iv):
i = 0
_update_plot(i)
-ani = animation.FuncAnimation(pl.gcf(), _update_plot, nv, interval=50, repeat_delay=2000)
+ani = animation.FuncAnimation(pl.gcf(), _update_plot, nv, interval=100, repeat_delay=2000)
##############################################################################
@@ -183,8 +183,8 @@ pl.show()
# Animation of the regpath for semi-relaxed UOT l2
# ------------------------------------------------
-nv = 100
-g_list_v = np.logspace(2.5, -2, nv)
+nv = 50
+g_list_v = np.logspace(2, -2, nv)
pl.figure(5)
@@ -216,4 +216,4 @@ def _update_plot(iv):
i = 0
_update_plot(i)
-ani = animation.FuncAnimation(pl.gcf(), _update_plot, nv, interval=50, repeat_delay=2000)
+ani = animation.FuncAnimation(pl.gcf(), _update_plot, nv, interval=100, repeat_delay=2000)