From ad02112d4288f3efdd5bc6fc6e45444313bba871 Mon Sep 17 00:00:00 2001 From: RĂ©mi Flamary Date: Tue, 5 Apr 2022 11:57:10 +0200 Subject: [MRG] Update examples in the doc (#359) * add transparent color logo * add transparent color logo * move screenkhorn * move stochastic and install ffmpeg on circleci * try something * add sudo * install ffmpeg before python * cleanup examples * test svg scrapper * add animation for reg path * better example OT sivergence * update ttles and add plots * update free support * proper figure indexes * have less frame sin animation * update readme and release file * add tests for python 3.10 --- examples/plot_OT_L1_vs_L2.py | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) (limited to 'examples/plot_OT_L1_vs_L2.py') diff --git a/examples/plot_OT_L1_vs_L2.py b/examples/plot_OT_L1_vs_L2.py index cb94574..cce51f8 100644 --- a/examples/plot_OT_L1_vs_L2.py +++ b/examples/plot_OT_L1_vs_L2.py @@ -1,10 +1,10 @@ # -*- coding: utf-8 -*- """ -========================================== -2D Optimal transport for different metrics -========================================== +================================================ +Optimal Transport with different gournd metrics +================================================ -2D OT on empirical distributio with different gound metric. +2D OT on empirical distributio with different ground metric. Stole the figure idea from Fig. 1 and 2 in https://arxiv.org/pdf/1706.07650.pdf @@ -23,7 +23,7 @@ import matplotlib.pylab as pl import ot import ot.plot -############################################################################## +# %% # Dataset 1 : uniform sampling # ---------------------------- @@ -46,7 +46,7 @@ M2 = ot.dist(xs, xt, metric='sqeuclidean') M2 /= M2.max() # loss matrix -Mp = np.sqrt(ot.dist(xs, xt, metric='euclidean')) +Mp = ot.dist(xs, xt, metric='cityblock') Mp /= Mp.max() # Data @@ -71,7 +71,7 @@ pl.title('Squared Euclidean cost') pl.subplot(1, 3, 3) pl.imshow(Mp, interpolation='nearest') -pl.title('Sqrt Euclidean cost') +pl.title('L1 (cityblock cost') pl.tight_layout() ############################################################################## @@ -109,22 +109,22 @@ pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') pl.axis('equal') # pl.legend(loc=0) -pl.title('OT sqrt Euclidean') +pl.title('OT L1 (cityblock)') pl.tight_layout() pl.show() -############################################################################## +# %% # Dataset 2 : Partial circle # -------------------------- -n = 50 # nb samples +n = 20 # nb samples xtot = np.zeros((n + 1, 2)) xtot[:, 0] = np.cos( - (np.arange(n + 1) + 1.0) * 0.9 / (n + 2) * 2 * np.pi) + (np.arange(n + 1) + 1.0) * 0.8 / (n + 2) * 2 * np.pi) xtot[:, 1] = np.sin( - (np.arange(n + 1) + 1.0) * 0.9 / (n + 2) * 2 * np.pi) + (np.arange(n + 1) + 1.0) * 0.8 / (n + 2) * 2 * np.pi) xs = xtot[:n, :] xt = xtot[1:, :] @@ -140,7 +140,7 @@ M2 = ot.dist(xs, xt, metric='sqeuclidean') M2 /= M2.max() # loss matrix -Mp = np.sqrt(ot.dist(xs, xt, metric='euclidean')) +Mp = ot.dist(xs, xt, metric='cityblock') Mp /= Mp.max() @@ -166,13 +166,13 @@ pl.title('Squared Euclidean cost') pl.subplot(1, 3, 3) pl.imshow(Mp, interpolation='nearest') -pl.title('Sqrt Euclidean cost') +pl.title('L1 (cityblock) cost') pl.tight_layout() ############################################################################## # Dataset 2 : Plot OT Matrices # ----------------------------- - +# #%% EMD G1 = ot.emd(a, b, M1) @@ -204,7 +204,7 @@ pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') pl.axis('equal') # pl.legend(loc=0) -pl.title('OT sqrt Euclidean') +pl.title('OT L1 (cityblock)') pl.tight_layout() pl.show() -- cgit v1.2.3