summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--examples/demo_OT_1D.py49
-rw-r--r--ot/__init__.py2
-rw-r--r--ot/plot.py40
3 files changed, 48 insertions, 43 deletions
diff --git a/examples/demo_OT_1D.py b/examples/demo_OT_1D.py
index 8ba40ba..b17f902 100644
--- a/examples/demo_OT_1D.py
+++ b/examples/demo_OT_1D.py
@@ -7,7 +7,7 @@ Created on Fri Oct 21 09:51:45 2016
import numpy as np
import matplotlib.pylab as pl
-from matplotlib import gridspec
+
import ot
@@ -16,18 +16,12 @@ import ot
n=100 # nb bins
-ma=20 # mean of a
-mb=60 # mean of b
-
-sa=20 # std of a
-sb=60 # std of b
-
# bin positions
x=np.arange(n,dtype=np.float64)
# Gaussian distributions
-a=ot.datasets.get_1D_gauss(n,ma,sa)
-b=ot.datasets.get_1D_gauss(n,mb,sb)
+a=ot.datasets.get_1D_gauss(n,m=20,s=20) # m= mean, s= std
+b=ot.datasets.get_1D_gauss(n,m=60,s=60)
# loss matrix
M=ot.dist(x.reshape((n,1)),x.reshape((n,1)))
@@ -44,47 +38,16 @@ pl.legend()
#%% plot distributions and loss matrix
-def plotmat(M,title=''):
- """ Plot a matrix woth the 1D distribution """
- gs = gridspec.GridSpec(3, 3)
-
- ax1=pl.subplot(gs[0,1:])
- pl.plot(x,b,'r',label='Target distribution')
- pl.yticks(())
- pl.title(title)
-
- #pl.axis('off')
-
- ax2=pl.subplot(gs[1:,0])
- pl.plot(a,x,'b',label='Source distribution')
- pl.gca().invert_xaxis()
- pl.gca().invert_yaxis()
- pl.xticks(())
- #pl.ylim((0,n))
- #pl.axis('off')
-
- pl.subplot(gs[1:,1:],sharex=ax1,sharey=ax2)
- pl.imshow(M,interpolation='nearest')
-
- pl.xlim((0,n))
-
pl.figure(2)
+ot.plot.otplot1D(a,b,M,'Cost matrix M')
-plotmat(M,'Cost matrix M')
-
-
-
-
-
-#pl.ylim((0,n))
-#pl.axis('off')
#%% EMD
G0=ot.emd(a,b,M)
pl.figure(3)
-plotmat(G0,'OT matrix G0')
+ot.plot.otplot1D(a,b,G0,'OT matrix G0')
#%% Sinkhorn
lambd=1e-3
@@ -92,4 +55,4 @@ lambd=1e-3
Gs=ot.sinkhorn(a,b,M,lambd)
pl.figure(4)
-plotmat(Gs,'OT matrix Sinkhorn')
+ot.plot.otplot1D(a,b,Gs,'OT matrix Sinkhorn')
diff --git a/ot/__init__.py b/ot/__init__.py
index 14c6181..fe55771 100644
--- a/ot/__init__.py
+++ b/ot/__init__.py
@@ -2,6 +2,8 @@
# utils submodules
import utils
import datasets
+import plot
+
# Ot functions
from emd import emd
diff --git a/ot/plot.py b/ot/plot.py
new file mode 100644
index 0000000..b8357b1
--- /dev/null
+++ b/ot/plot.py
@@ -0,0 +1,40 @@
+
+import numpy as np
+import matplotlib.pylab as pl
+from matplotlib import gridspec
+
+
+def otplot1D(a,b,M,title=''):
+ """ Plot a with the source and target 1D distribution """
+
+ na=M.shape[0]
+ nb=M.shape[1]
+
+ gs = gridspec.GridSpec(3, 3)
+
+
+ xa=np.arange(na)
+ xb=np.arange(nb)
+
+
+ ax1=pl.subplot(gs[0,1:])
+ pl.plot(xb,b,'r',label='Target distribution')
+ pl.yticks(())
+ pl.title(title)
+
+ #pl.axis('off')
+
+ ax2=pl.subplot(gs[1:,0])
+ pl.plot(a,xa,'b',label='Source distribution')
+ pl.gca().invert_xaxis()
+ pl.gca().invert_yaxis()
+ pl.xticks(())
+ #pl.ylim((0,n))
+ #pl.axis('off')
+
+ pl.subplot(gs[1:,1:],sharex=ax1,sharey=ax2)
+ pl.imshow(M,interpolation='nearest')
+
+ pl.xlim((0,nb))
+
+