diff options
Diffstat (limited to 'examples/backends/plot_sliced_wass_grad_flow_pytorch.py')
-rw-r--r-- | examples/backends/plot_sliced_wass_grad_flow_pytorch.py | 8 |
1 files changed, 4 insertions, 4 deletions
diff --git a/examples/backends/plot_sliced_wass_grad_flow_pytorch.py b/examples/backends/plot_sliced_wass_grad_flow_pytorch.py index 07a4926..7cbfd98 100644 --- a/examples/backends/plot_sliced_wass_grad_flow_pytorch.py +++ b/examples/backends/plot_sliced_wass_grad_flow_pytorch.py @@ -40,8 +40,8 @@ import torch import ot import matplotlib.animation as animation -I1 = pl.imread('../../data/redcross.png').astype(np.float64)[::4, ::4, 2] -I2 = pl.imread('../../data/tooth.png').astype(np.float64)[::4, ::4, 2] +I1 = pl.imread('../../data/redcross.png').astype(np.float64)[::5, ::5, 2] +I2 = pl.imread('../../data/tooth.png').astype(np.float64)[::5, ::5, 2] sz = I2.shape[0] XX, YY = np.meshgrid(np.arange(sz), np.arange(sz)) @@ -67,7 +67,7 @@ x2_torch = torch.tensor(x2).to(device=device) lr = 1e3 -nb_iter_max = 100 +nb_iter_max = 50 x_all = np.zeros((nb_iter_max, x1.shape[0], 2)) @@ -129,7 +129,7 @@ xbinit = np.random.randn(500, 2) * 10 + 16 xbary_torch = torch.tensor(xbinit).to(device=device).requires_grad_(True) lr = 1e3 -nb_iter_max = 100 +nb_iter_max = 50 x_all = np.zeros((nb_iter_max, xbary_torch.shape[0], 2)) |