summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2016-10-21 10:51:27 +0200
committerRémi Flamary <remi.flamary@gmail.com>2016-10-21 10:51:27 +0200
commit581c6de782dca279edd97778cc474e7597788c0f (patch)
tree760161e1c7812d8caf77bf8acc543453c6213e39
parent2109443f5bea396114d1f9e0563ba5c396378c57 (diff)
demo+sinkhorn
-rw-r--r--examples/demo_OT_1D.py102
-rw-r--r--ot/__init__.py6
-rw-r--r--ot/bregman/__init__.py4
-rw-r--r--ot/bregman/sinkhorn.py91
-rw-r--r--ot/datasets.py10
-rw-r--r--ot/emd/emd.pyx16
-rw-r--r--ot/utils.py15
-rwxr-xr-xsetup.py4
8 files changed, 246 insertions, 2 deletions
diff --git a/examples/demo_OT_1D.py b/examples/demo_OT_1D.py
new file mode 100644
index 0000000..29f2074
--- /dev/null
+++ b/examples/demo_OT_1D.py
@@ -0,0 +1,102 @@
+# -*- coding: utf-8 -*-
+"""
+Created on Fri Oct 21 09:51:45 2016
+
+@author: rflamary
+"""
+
+import numpy as np
+import matplotlib.pylab as pl
+from matplotlib import gridspec
+import ot
+
+
+
+#%% parameters
+
+n=100 # nb bins
+
+ma=20 # mean of a
+mb=60 # mean of b
+
+sa=20 # std of a
+sb=60 # std of b
+
+# bin positions
+x=np.arange(n,dtype=np.float64)
+
+# Gaussian distributions
+a=np.exp(-(x-ma)**2/(2*sa^2))
+b=np.exp(-(x-mb)**2/(2*sb^2))
+
+# normalization
+a/=a.sum()
+b/=b.sum()
+
+# loss matrix
+M=ot.dist(x.reshape((n,1)),x.reshape((n,1)))
+M/=M.max()
+
+#%% plot the distributions
+
+pl.figure(1)
+
+pl.plot(x,a,'b',label='Source distribution')
+pl.plot(x,b,'r',label='Target distribution')
+
+pl.legend()
+
+#%% plot distributions and loss matrix
+
+pl.figure(2)
+gs = gridspec.GridSpec(3, 3)
+
+ax1=pl.subplot(gs[0,1:])
+pl.plot(x,b,'r',label='Target distribution')
+pl.yticks(())
+
+#pl.axis('off')
+
+ax2=pl.subplot(gs[1:,0])
+pl.plot(a,x,'b',label='Source distribution')
+pl.gca().invert_xaxis()
+pl.gca().invert_yaxis()
+pl.xticks(())
+#pl.ylim((0,n))
+#pl.axis('off')
+
+pl.subplot(gs[1:,1:],sharex=ax1,sharey=ax2)
+pl.imshow(M,interpolation='nearest')
+
+pl.xlim((0,n))
+#pl.ylim((0,n))
+#pl.axis('off')
+
+#%% EMD
+
+G0=ot.emd(a,b,M)
+
+#%% plot EMD optimal tranport matrix
+pl.figure(3)
+gs = gridspec.GridSpec(3, 3)
+
+ax1=pl.subplot(gs[0,1:])
+pl.plot(x,b,'r',label='Target distribution')
+pl.yticks(())
+
+#pl.axis('off')
+
+ax2=pl.subplot(gs[1:,0])
+pl.plot(a,x,'b',label='Source distribution')
+pl.gca().invert_xaxis()
+pl.gca().invert_yaxis()
+pl.xticks(())
+#pl.ylim((0,n))
+#pl.axis('off')
+
+pl.subplot(gs[1:,1:],sharex=ax1,sharey=ax2)
+pl.imshow(G0,interpolation='nearest')
+
+pl.xlim((0,n))
+#pl.ylim((0,n))
+#pl.axis('off') \ No newline at end of file
diff --git a/ot/__init__.py b/ot/__init__.py
index d0ab3f7..beeae7f 100644
--- a/ot/__init__.py
+++ b/ot/__init__.py
@@ -1,6 +1,12 @@
+
from emd import emd
+from bregman import sinkhorn
+
+import utils
+import datasets
+from utils import dist,dots
__all__ = ["emd"]
diff --git a/ot/bregman/__init__.py b/ot/bregman/__init__.py
new file mode 100644
index 0000000..e4016ea
--- /dev/null
+++ b/ot/bregman/__init__.py
@@ -0,0 +1,4 @@
+
+
+from .sink import sinkhorn
+
diff --git a/ot/bregman/sinkhorn.py b/ot/bregman/sinkhorn.py
new file mode 100644
index 0000000..798ac97
--- /dev/null
+++ b/ot/bregman/sinkhorn.py
@@ -0,0 +1,91 @@
+# -*- coding: utf-8 -*-
+"""
+Created on Fri Oct 21 09:40:21 2016
+
+@author: rflamary
+"""
+
+import numpy as np
+
+
+def sinkhorn(a,b, M, reg,numItermax = 1000,stopThr=1e-9):
+ """
+ Solve the optimal transport problem (OT)
+
+ .. math::
+ \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
+
+ s.t. \gamma 1 = a
+
+ \gamma^T 1= b
+
+ \gamma\geq 0
+ where :
+
+ - M is the metric cost matrix
+ - Omega is the entropic regularization term
+ - a and b are the sample weights
+
+ Parameters
+ ----------
+ a : (ns,) ndarray
+ samples in the source domain
+ b : (nt,) ndarray
+ samples in the target domain
+ M : (ns,nt) ndarray
+ loss matrix
+ reg: float()
+ Regularization term >0
+
+
+ Returns
+ -------
+ gamma: (ns x nt) ndarray
+ Optimal transportation matrix for the given parameters
+
+ """
+ # init data
+ Nini = len(a)
+ Nfin = len(b)
+
+
+ cpt = 0
+
+ # we assume that no distances are null except those of the diagonal of distances
+ u = np.ones(Nini)/Nini
+ v = np.ones(Nfin)/Nfin
+ uprev=np.zeros(Nini)
+ vprev=np.zeros(Nini)
+
+ #print reg
+
+ K = np.exp(-reg*M)
+ #print np.min(K)
+
+ Kp = np.dot(np.diag(1/a),K)
+ transp = K
+ cpt = 0
+ err=1
+ while (err>stopThr and cpt<numItermax):
+ if np.any(np.dot(K.T,u)==0) or np.any(np.isnan(u)) or np.any(np.isnan(v)):
+ # we have reached the machine precision
+ # come back to previous solution and quit loop
+ print('Warning: numerical errrors')
+ if cpt!=0:
+ u = uprev
+ v = vprev
+ break
+ uprev = u
+ vprev = v
+ v = np.divide(b,np.dot(K.T,u))
+ u = 1./np.dot(Kp,v)
+ if cpt%10==0:
+ # we can speed up the process by checking for the error only all the 10th iterations
+ transp = np.dot(np.diag(u),np.dot(K,np.diag(v)))
+ err = np.linalg.norm((np.sum(transp,axis=0)-b))**2
+ cpt = cpt +1
+ #print 'err=',err,' cpt=',cpt
+
+ return np.dot(np.diag(u),np.dot(K,np.diag(v)))
+
+
diff --git a/ot/datasets.py b/ot/datasets.py
new file mode 100644
index 0000000..bb10ba4
--- /dev/null
+++ b/ot/datasets.py
@@ -0,0 +1,10 @@
+
+import numpy as np
+
+
+
+def get_1D_gauss(n,m,s):
+ "return a 1D histogram for a gaussian distribution (n bins, mean m and std s) "
+ x=np.arange(n,dtype=np.float64)
+ h=np.exp(-(x-m)**2/(2*s^2))
+ return h/h.sum() \ No newline at end of file
diff --git a/ot/emd/emd.pyx b/ot/emd/emd.pyx
index d090cea..e5ac8e0 100644
--- a/ot/emd/emd.pyx
+++ b/ot/emd/emd.pyx
@@ -19,6 +19,22 @@ cdef extern from "EMD.h":
@cython.boundscheck(False)
@cython.wraparound(False)
def emd( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode="c"] b,np.ndarray[double, ndim=2, mode="c"] M):
+ """
+ Solves the Earth Movers distance problem and returns the optimal transport matrix
+
+
+
+ :param a: m weights of the source distribution (must sum to one)
+ :param b: n weights of the target distribution (must sum to one)
+ :param M: m x n cost matrix
+ :type a: np.ndarray
+ :type b: np.ndarray
+ :type M: np.ndarray
+ :return: Optimal transport matrix
+ :rtype: np.ndarray
+
+
+ """
cdef int n1= M.shape[0]
cdef int n2= M.shape[1]
diff --git a/ot/utils.py b/ot/utils.py
new file mode 100644
index 0000000..1a1c6b8
--- /dev/null
+++ b/ot/utils.py
@@ -0,0 +1,15 @@
+
+import numpy as np
+from scipy.spatial.distance import cdist, pdist
+
+
+def dist(x1,x2=None,metric='sqeuclidean'):
+ """Compute distance between samples in x1 and x2"""
+ if x2 is None:
+ return pdist(x1,metric=metric)
+ else:
+ return cdist(x1,x2,metric=metric)
+
+def dots(*args):
+ """ Stupid but nice dots function for multiple matrix multiply """
+ return reduce(np.dot,args) \ No newline at end of file
diff --git a/setup.py b/setup.py
index 3e31742..ddedb36 100755
--- a/setup.py
+++ b/setup.py
@@ -2,11 +2,11 @@
from distutils.core import setup, Extension
import numpy
-from Cython.Distutils import build_ext
+#from Cython.Distutils import build_ext
from Cython.Build import cythonize
import os
-import glob
+#import glob
version=0.1