summaryrefslogtreecommitdiff
path: root/examples
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2018-02-20 16:11:56 +0100
committerRémi Flamary <remi.flamary@gmail.com>2018-02-20 16:11:56 +0100
commit6d9b281271167d3676538f2ef8518abea82ef9c8 (patch)
tree7d1fae1d15a0ec70e229819a68b9f3a1ceea8f02 /examples
parent806a406e1ca2e9ca0bfdfe0516c75865e8098205 (diff)
parent5ff8030ce300f3d066e1edba2b36e60709b023b8 (diff)
Merge branch 'master' of github.com:rflamary/POT
Diffstat (limited to 'examples')
-rw-r--r--examples/plot_OT_1D.py1
-rw-r--r--examples/plot_OT_L1_vs_L2.py3
-rw-r--r--examples/plot_barycenter_1D.py8
-rw-r--r--examples/plot_gromov.py39
-rwxr-xr-xexamples/plot_gromov_barycenter.py8
-rw-r--r--examples/plot_optim_OTreg.py2
-rw-r--r--examples/plot_otda_d2.py2
7 files changed, 39 insertions, 24 deletions
diff --git a/examples/plot_OT_1D.py b/examples/plot_OT_1D.py
index 719058f..90325c9 100644
--- a/examples/plot_OT_1D.py
+++ b/examples/plot_OT_1D.py
@@ -16,6 +16,7 @@ and their visualization.
import numpy as np
import matplotlib.pylab as pl
import ot
+import ot.plot
from ot.datasets import get_1D_gauss as gauss
##############################################################################
diff --git a/examples/plot_OT_L1_vs_L2.py b/examples/plot_OT_L1_vs_L2.py
index 090e809..37b429f 100644
--- a/examples/plot_OT_L1_vs_L2.py
+++ b/examples/plot_OT_L1_vs_L2.py
@@ -19,6 +19,7 @@ https://arxiv.org/pdf/1706.07650.pdf
import numpy as np
import matplotlib.pylab as pl
import ot
+import ot.plot
##############################################################################
# Dataset 1 : uniform sampling
@@ -52,7 +53,7 @@ pl.clf()
pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
pl.axis('equal')
-pl.title('Source and traget distributions')
+pl.title('Source and target distributions')
# Cost matrices
diff --git a/examples/plot_barycenter_1D.py b/examples/plot_barycenter_1D.py
index 620936b..ecf640c 100644
--- a/examples/plot_barycenter_1D.py
+++ b/examples/plot_barycenter_1D.py
@@ -25,7 +25,7 @@ import ot
from mpl_toolkits.mplot3d import Axes3D # noqa
from matplotlib.collections import PolyCollection
-##############################################################################
+#
# Generate data
# -------------
@@ -48,7 +48,7 @@ n_distributions = A.shape[1]
M = ot.utils.dist0(n)
M /= M.max()
-##############################################################################
+#
# Plot data
# ---------
@@ -60,7 +60,7 @@ for i in range(n_distributions):
pl.title('Distributions')
pl.tight_layout()
-##############################################################################
+#
# Barycenter computation
# ----------------------
@@ -90,7 +90,7 @@ pl.legend()
pl.title('Barycenters')
pl.tight_layout()
-##############################################################################
+#
# Barycentric interpolation
# -------------------------
diff --git a/examples/plot_gromov.py b/examples/plot_gromov.py
index d3f724c..5cd40f6 100644
--- a/examples/plot_gromov.py
+++ b/examples/plot_gromov.py
@@ -19,8 +19,8 @@ import matplotlib.pylab as pl
from mpl_toolkits.mplot3d import Axes3D # noqa
import ot
-
-##############################################################################
+#############################################################################
+#
# Sample two Gaussian distributions (2D and 3D)
# ---------------------------------------------
#
@@ -42,8 +42,8 @@ xs = ot.datasets.get_2D_samples_gauss(n_samples, mu_s, cov_s)
P = sp.linalg.sqrtm(cov_t)
xt = np.random.randn(n_samples, 3).dot(P) + mu_t
-
-##############################################################################
+#############################################################################
+#
# Plotting the distributions
# --------------------------
@@ -55,8 +55,8 @@ ax2 = fig.add_subplot(122, projection='3d')
ax2.scatter(xt[:, 0], xt[:, 1], xt[:, 2], color='r')
pl.show()
-
-##############################################################################
+#############################################################################
+#
# Compute distance kernels, normalize them and then display
# ---------------------------------------------------------
@@ -74,20 +74,33 @@ pl.subplot(122)
pl.imshow(C2)
pl.show()
-##############################################################################
+#############################################################################
+#
# Compute Gromov-Wasserstein plans and distance
# ---------------------------------------------
-
p = ot.unif(n_samples)
q = ot.unif(n_samples)
-gw = ot.gromov_wasserstein(C1, C2, p, q, 'square_loss', epsilon=5e-4)
-gw_dist = ot.gromov_wasserstein2(C1, C2, p, q, 'square_loss', epsilon=5e-4)
+gw0, log0 = ot.gromov.gromov_wasserstein(
+ C1, C2, p, q, 'square_loss', verbose=True, log=True)
-print('Gromov-Wasserstein distances between the distribution: ' + str(gw_dist))
+gw, log = ot.gromov.entropic_gromov_wasserstein(
+ C1, C2, p, q, 'square_loss', epsilon=5e-4, log=True, verbose=True)
-pl.figure()
+
+print('Gromov-Wasserstein distances: ' + str(log0['gw_dist']))
+print('Entropic Gromov-Wasserstein distances: ' + str(log['gw_dist']))
+
+
+pl.figure(1, (10, 5))
+
+pl.subplot(1, 2, 1)
+pl.imshow(gw0, cmap='jet')
+pl.title('Gromov Wasserstein')
+
+pl.subplot(1, 2, 2)
pl.imshow(gw, cmap='jet')
-pl.colorbar()
+pl.title('Entropic Gromov Wasserstein')
+
pl.show()
diff --git a/examples/plot_gromov_barycenter.py b/examples/plot_gromov_barycenter.py
index 180b0cf..58fc51a 100755
--- a/examples/plot_gromov_barycenter.py
+++ b/examples/plot_gromov_barycenter.py
@@ -132,28 +132,28 @@ Ct01 = [0 for i in range(2)]
for i in range(2):
Ct01[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[0], Cs[1]],
[ps[0], ps[1]
- ], p, lambdast[i], 'square_loss', 5e-4,
+ ], p, lambdast[i], 'square_loss', # 5e-4,
max_iter=100, tol=1e-3)
Ct02 = [0 for i in range(2)]
for i in range(2):
Ct02[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[0], Cs[2]],
[ps[0], ps[2]
- ], p, lambdast[i], 'square_loss', 5e-4,
+ ], p, lambdast[i], 'square_loss', # 5e-4,
max_iter=100, tol=1e-3)
Ct13 = [0 for i in range(2)]
for i in range(2):
Ct13[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[1], Cs[3]],
[ps[1], ps[3]
- ], p, lambdast[i], 'square_loss', 5e-4,
+ ], p, lambdast[i], 'square_loss', # 5e-4,
max_iter=100, tol=1e-3)
Ct23 = [0 for i in range(2)]
for i in range(2):
Ct23[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[2], Cs[3]],
[ps[2], ps[3]
- ], p, lambdast[i], 'square_loss', 5e-4,
+ ], p, lambdast[i], 'square_loss', # 5e-4,
max_iter=100, tol=1e-3)
diff --git a/examples/plot_optim_OTreg.py b/examples/plot_optim_OTreg.py
index e1a737e..92df016 100644
--- a/examples/plot_optim_OTreg.py
+++ b/examples/plot_optim_OTreg.py
@@ -28,7 +28,7 @@ arXiv preprint arXiv:1510.06567.
import numpy as np
import matplotlib.pylab as pl
import ot
-
+import ot.plot
##############################################################################
# Generate data
diff --git a/examples/plot_otda_d2.py b/examples/plot_otda_d2.py
index e53d7d6..70beb35 100644
--- a/examples/plot_otda_d2.py
+++ b/examples/plot_otda_d2.py
@@ -20,7 +20,7 @@ of what the transport methods are doing.
import matplotlib.pylab as pl
import ot
-
+import ot.plot
##############################################################################
# generate data