summaryrefslogtreecommitdiff
path: root/examples/backends/plot_sliced_wass_grad_flow_pytorch.py
diff options
context:
space:
mode:
Diffstat (limited to 'examples/backends/plot_sliced_wass_grad_flow_pytorch.py')
-rw-r--r--examples/backends/plot_sliced_wass_grad_flow_pytorch.py8
1 files changed, 4 insertions, 4 deletions
diff --git a/examples/backends/plot_sliced_wass_grad_flow_pytorch.py b/examples/backends/plot_sliced_wass_grad_flow_pytorch.py
index 07a4926..7cbfd98 100644
--- a/examples/backends/plot_sliced_wass_grad_flow_pytorch.py
+++ b/examples/backends/plot_sliced_wass_grad_flow_pytorch.py
@@ -40,8 +40,8 @@ import torch
import ot
import matplotlib.animation as animation
-I1 = pl.imread('../../data/redcross.png').astype(np.float64)[::4, ::4, 2]
-I2 = pl.imread('../../data/tooth.png').astype(np.float64)[::4, ::4, 2]
+I1 = pl.imread('../../data/redcross.png').astype(np.float64)[::5, ::5, 2]
+I2 = pl.imread('../../data/tooth.png').astype(np.float64)[::5, ::5, 2]
sz = I2.shape[0]
XX, YY = np.meshgrid(np.arange(sz), np.arange(sz))
@@ -67,7 +67,7 @@ x2_torch = torch.tensor(x2).to(device=device)
lr = 1e3
-nb_iter_max = 100
+nb_iter_max = 50
x_all = np.zeros((nb_iter_max, x1.shape[0], 2))
@@ -129,7 +129,7 @@ xbinit = np.random.randn(500, 2) * 10 + 16
xbary_torch = torch.tensor(xbinit).to(device=device).requires_grad_(True)
lr = 1e3
-nb_iter_max = 100
+nb_iter_max = 50
x_all = np.zeros((nb_iter_max, xbary_torch.shape[0], 2))