summaryrefslogtreecommitdiff
path: root/ot
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2016-10-21 10:51:27 +0200
committerRémi Flamary <remi.flamary@gmail.com>2016-10-21 10:51:27 +0200
commit581c6de782dca279edd97778cc474e7597788c0f (patch)
tree760161e1c7812d8caf77bf8acc543453c6213e39 /ot
parent2109443f5bea396114d1f9e0563ba5c396378c57 (diff)
demo+sinkhorn
Diffstat (limited to 'ot')
-rw-r--r--ot/__init__.py6
-rw-r--r--ot/bregman/__init__.py4
-rw-r--r--ot/bregman/sinkhorn.py91
-rw-r--r--ot/datasets.py10
-rw-r--r--ot/emd/emd.pyx16
-rw-r--r--ot/utils.py15
6 files changed, 142 insertions, 0 deletions
diff --git a/ot/__init__.py b/ot/__init__.py
index d0ab3f7..beeae7f 100644
--- a/ot/__init__.py
+++ b/ot/__init__.py
@@ -1,6 +1,12 @@
+
from emd import emd
+from bregman import sinkhorn
+
+import utils
+import datasets
+from utils import dist,dots
__all__ = ["emd"]
diff --git a/ot/bregman/__init__.py b/ot/bregman/__init__.py
new file mode 100644
index 0000000..e4016ea
--- /dev/null
+++ b/ot/bregman/__init__.py
@@ -0,0 +1,4 @@
+
+
+from .sink import sinkhorn
+
diff --git a/ot/bregman/sinkhorn.py b/ot/bregman/sinkhorn.py
new file mode 100644
index 0000000..798ac97
--- /dev/null
+++ b/ot/bregman/sinkhorn.py
@@ -0,0 +1,91 @@
+# -*- coding: utf-8 -*-
+"""
+Created on Fri Oct 21 09:40:21 2016
+
+@author: rflamary
+"""
+
+import numpy as np
+
+
+def sinkhorn(a,b, M, reg,numItermax = 1000,stopThr=1e-9):
+ """
+ Solve the optimal transport problem (OT)
+
+ .. math::
+ \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
+
+ s.t. \gamma 1 = a
+
+ \gamma^T 1= b
+
+ \gamma\geq 0
+ where :
+
+ - M is the metric cost matrix
+ - Omega is the entropic regularization term
+ - a and b are the sample weights
+
+ Parameters
+ ----------
+ a : (ns,) ndarray
+ samples in the source domain
+ b : (nt,) ndarray
+ samples in the target domain
+ M : (ns,nt) ndarray
+ loss matrix
+ reg: float()
+ Regularization term >0
+
+
+ Returns
+ -------
+ gamma: (ns x nt) ndarray
+ Optimal transportation matrix for the given parameters
+
+ """
+ # init data
+ Nini = len(a)
+ Nfin = len(b)
+
+
+ cpt = 0
+
+ # we assume that no distances are null except those of the diagonal of distances
+ u = np.ones(Nini)/Nini
+ v = np.ones(Nfin)/Nfin
+ uprev=np.zeros(Nini)
+ vprev=np.zeros(Nini)
+
+ #print reg
+
+ K = np.exp(-reg*M)
+ #print np.min(K)
+
+ Kp = np.dot(np.diag(1/a),K)
+ transp = K
+ cpt = 0
+ err=1
+ while (err>stopThr and cpt<numItermax):
+ if np.any(np.dot(K.T,u)==0) or np.any(np.isnan(u)) or np.any(np.isnan(v)):
+ # we have reached the machine precision
+ # come back to previous solution and quit loop
+ print('Warning: numerical errrors')
+ if cpt!=0:
+ u = uprev
+ v = vprev
+ break
+ uprev = u
+ vprev = v
+ v = np.divide(b,np.dot(K.T,u))
+ u = 1./np.dot(Kp,v)
+ if cpt%10==0:
+ # we can speed up the process by checking for the error only all the 10th iterations
+ transp = np.dot(np.diag(u),np.dot(K,np.diag(v)))
+ err = np.linalg.norm((np.sum(transp,axis=0)-b))**2
+ cpt = cpt +1
+ #print 'err=',err,' cpt=',cpt
+
+ return np.dot(np.diag(u),np.dot(K,np.diag(v)))
+
+
diff --git a/ot/datasets.py b/ot/datasets.py
new file mode 100644
index 0000000..bb10ba4
--- /dev/null
+++ b/ot/datasets.py
@@ -0,0 +1,10 @@
+
+import numpy as np
+
+
+
+def get_1D_gauss(n,m,s):
+ "return a 1D histogram for a gaussian distribution (n bins, mean m and std s) "
+ x=np.arange(n,dtype=np.float64)
+ h=np.exp(-(x-m)**2/(2*s^2))
+ return h/h.sum() \ No newline at end of file
diff --git a/ot/emd/emd.pyx b/ot/emd/emd.pyx
index d090cea..e5ac8e0 100644
--- a/ot/emd/emd.pyx
+++ b/ot/emd/emd.pyx
@@ -19,6 +19,22 @@ cdef extern from "EMD.h":
@cython.boundscheck(False)
@cython.wraparound(False)
def emd( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode="c"] b,np.ndarray[double, ndim=2, mode="c"] M):
+ """
+ Solves the Earth Movers distance problem and returns the optimal transport matrix
+
+
+
+ :param a: m weights of the source distribution (must sum to one)
+ :param b: n weights of the target distribution (must sum to one)
+ :param M: m x n cost matrix
+ :type a: np.ndarray
+ :type b: np.ndarray
+ :type M: np.ndarray
+ :return: Optimal transport matrix
+ :rtype: np.ndarray
+
+
+ """
cdef int n1= M.shape[0]
cdef int n2= M.shape[1]
diff --git a/ot/utils.py b/ot/utils.py
new file mode 100644
index 0000000..1a1c6b8
--- /dev/null
+++ b/ot/utils.py
@@ -0,0 +1,15 @@
+
+import numpy as np
+from scipy.spatial.distance import cdist, pdist
+
+
+def dist(x1,x2=None,metric='sqeuclidean'):
+ """Compute distance between samples in x1 and x2"""
+ if x2 is None:
+ return pdist(x1,metric=metric)
+ else:
+ return cdist(x1,x2,metric=metric)
+
+def dots(*args):
+ """ Stupid but nice dots function for multiple matrix multiply """
+ return reduce(np.dot,args) \ No newline at end of file