summaryrefslogtreecommitdiff
path: root/examples/plot_gromov_barycenter.py
diff options
context:
space:
mode:
authorNicolas Courty <Nico@MacBook-Pro-de-Nicolas.local>2017-09-01 11:22:13 +0200
committerNicolas Courty <Nico@MacBook-Pro-de-Nicolas.local>2017-09-01 11:22:13 +0200
commit64a5d3c4e49688c13d236baf9ed23420070024d6 (patch)
treeffe5db073c07e579b26ead6a8ebcb0ff78ce6a33 /examples/plot_gromov_barycenter.py
parentab6ed1df93cd78bb7f1a54282103d4d830e68bcb (diff)
parent986f46ddde3ce2f550cb56f66620df377326423d (diff)
docstrings and naming
Diffstat (limited to 'examples/plot_gromov_barycenter.py')
-rwxr-xr-xexamples/plot_gromov_barycenter.py33
1 files changed, 33 insertions, 0 deletions
diff --git a/examples/plot_gromov_barycenter.py b/examples/plot_gromov_barycenter.py
index f0657e1..46ec4bc 100755
--- a/examples/plot_gromov_barycenter.py
+++ b/examples/plot_gromov_barycenter.py
@@ -91,12 +91,21 @@ def im2mat(I):
return I.reshape((I.shape[0] * I.shape[1], I.shape[2]))
+<<<<<<< HEAD
square = spi.imread('../data/carre.png').astype(np.float64) / 256
circle = spi.imread('../data/rond.png').astype(np.float64) / 256
triangle = spi.imread('../data/triangle.png').astype(np.float64) / 256
arrow = spi.imread('../data/coeur.png').astype(np.float64) / 256
shapes = [square, circle, triangle, arrow]
+=======
+carre = spi.imread('../data/carre.png').astype(np.float64) / 256
+rond = spi.imread('../data/rond.png').astype(np.float64) / 256
+triangle = spi.imread('../data/triangle.png').astype(np.float64) / 256
+fleche = spi.imread('../data/coeur.png').astype(np.float64) / 256
+
+shapes = [carre, rond, triangle, fleche]
+>>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d
S = 4
xs = [[] for i in range(S)]
@@ -118,36 +127,60 @@ Barycenter computation
The four distributions are constructed from 4 simple images
"""
ns = [len(xs[s]) for s in range(S)]
+<<<<<<< HEAD
n_samples = 30
+=======
+N = 30
+>>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d
"""Compute all distances matrices for the four shapes"""
Cs = [sp.spatial.distance.cdist(xs[s], xs[s]) for s in range(S)]
Cs = [cs / cs.max() for cs in Cs]
ps = [ot.unif(ns[s]) for s in range(S)]
+<<<<<<< HEAD
p = ot.unif(n_samples)
+=======
+p = ot.unif(N)
+>>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d
lambdast = [[float(i) / 3, float(3 - i) / 3] for i in [1, 2]]
Ct01 = [0 for i in range(2)]
for i in range(2):
+<<<<<<< HEAD
Ct01[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[0], Cs[1]], [
+=======
+ Ct01[i] = ot.gromov.gromov_barycenters(N, [Cs[0], Cs[1]], [
+>>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d
ps[0], ps[1]], p, lambdast[i], 'square_loss', 5e-4, numItermax=100, stopThr=1e-3)
Ct02 = [0 for i in range(2)]
for i in range(2):
+<<<<<<< HEAD
Ct02[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[0], Cs[2]], [
+=======
+ Ct02[i] = ot.gromov.gromov_barycenters(N, [Cs[0], Cs[2]], [
+>>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d
ps[0], ps[2]], p, lambdast[i], 'square_loss', 5e-4, numItermax=100, stopThr=1e-3)
Ct13 = [0 for i in range(2)]
for i in range(2):
+<<<<<<< HEAD
Ct13[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[1], Cs[3]], [
+=======
+ Ct13[i] = ot.gromov.gromov_barycenters(N, [Cs[1], Cs[3]], [
+>>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d
ps[1], ps[3]], p, lambdast[i], 'square_loss', 5e-4, numItermax=100, stopThr=1e-3)
Ct23 = [0 for i in range(2)]
for i in range(2):
+<<<<<<< HEAD
Ct23[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[2], Cs[3]], [
+=======
+ Ct23[i] = ot.gromov.gromov_barycenters(N, [Cs[2], Cs[3]], [
+>>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d
ps[2], ps[3]], p, lambdast[i], 'square_loss', 5e-4, numItermax=100, stopThr=1e-3)
"""