summaryrefslogtreecommitdiff
path: root/examples/backends/plot_stoch_continuous_ot_pytorch.py
diff options
context:
space:
mode:
authorRĂ©mi Flamary <remi.flamary@gmail.com>2023-06-09 20:26:52 +0200
committerGitHub <noreply@github.com>2023-06-09 20:26:52 +0200
commit6c1e1f3e064165d37e22acc866c6fff56e3ab6ad (patch)
tree354aa9c554a9e7490c93bd2ac579675f1b933329 /examples/backends/plot_stoch_continuous_ot_pytorch.py
parent5faa4fbdb1a64351a42d31dd6f54f0402c29c405 (diff)
[MRG] Update tests and documentation (#484)
* remove old macos and windows tets update requirements * speedup ssw and continuaous ot exmaples * speedup regpath and variane * speedup conv 2d example + continuous stick * speedup regpath
Diffstat (limited to 'examples/backends/plot_stoch_continuous_ot_pytorch.py')
-rw-r--r--examples/backends/plot_stoch_continuous_ot_pytorch.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/examples/backends/plot_stoch_continuous_ot_pytorch.py b/examples/backends/plot_stoch_continuous_ot_pytorch.py
index 714a5d3..e642986 100644
--- a/examples/backends/plot_stoch_continuous_ot_pytorch.py
+++ b/examples/backends/plot_stoch_continuous_ot_pytorch.py
@@ -27,8 +27,8 @@ import ot.plot
torch.manual_seed(42)
np.random.seed(42)
-n_source_samples = 10000
-n_target_samples = 10000
+n_source_samples = 1000
+n_target_samples = 1000
theta = 2 * np.pi / 20
noise_level = 0.1
@@ -89,7 +89,7 @@ reg = 1
optimizer = torch.optim.Adam(list(u.parameters()) + list(v.parameters()), lr=.005)
# number of iteration
-n_iter = 1000
+n_iter = 500
n_batch = 500