summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2016-10-24 15:51:31 +0200
committerRémi Flamary <remi.flamary@gmail.com>2016-10-24 15:51:31 +0200
commitd130b55fd5845bf0848bb02cebc58ce1ae89f8a3 (patch)
treefc037c5ed6fd73ef1a00fceaa2ae36a5b430791c
parent6ee839d64d8b0f5f73fd5899032f2ae4bd8a7a51 (diff)
bregman as module
-rw-r--r--ot/__init__.py2
-rw-r--r--ot/bregman/__init__.py4
-rw-r--r--ot/bregman/sink.py91
-rwxr-xr-xsetup.py2
4 files changed, 2 insertions, 97 deletions
diff --git a/ot/__init__.py b/ot/__init__.py
index 0a9e89b..f5490f7 100644
--- a/ot/__init__.py
+++ b/ot/__init__.py
@@ -14,4 +14,4 @@ from bregman import sinkhorn
# utils functions
from utils import dist,dots,unif
-__all__ = ["emd","sinkhorn","utils",'datasets','plot','dist','dots']
+__all__ = ["emd","sinkhorn","utils",'datasets','bregman','plot','dist','dots']
diff --git a/ot/bregman/__init__.py b/ot/bregman/__init__.py
deleted file mode 100644
index e4016ea..0000000
--- a/ot/bregman/__init__.py
+++ /dev/null
@@ -1,4 +0,0 @@
-
-
-from .sink import sinkhorn
-
diff --git a/ot/bregman/sink.py b/ot/bregman/sink.py
deleted file mode 100644
index 8b97e1e..0000000
--- a/ot/bregman/sink.py
+++ /dev/null
@@ -1,91 +0,0 @@
-# -*- coding: utf-8 -*-
-"""
-Created on Fri Oct 21 09:40:21 2016
-
-@author: rflamary
-"""
-
-import numpy as np
-
-
-def sinkhorn(a,b, M, reg,numItermax = 1000,stopThr=1e-9):
- """
- Solve the optimal transport problem (OT)
-
- .. math::
- \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
-
- s.t. \gamma 1 = a
-
- \gamma^T 1= b
-
- \gamma\geq 0
- where :
-
- - M is the metric cost matrix
- - Omega is the entropic regularization term
- - a and b are the sample weights
-
- Parameters
- ----------
- a : (ns,) ndarray
- samples in the source domain
- b : (nt,) ndarray
- samples in the target domain
- M : (ns,nt) ndarray
- loss matrix
- reg: float()
- Regularization term >0
-
-
- Returns
- -------
- gamma: (ns x nt) ndarray
- Optimal transportation matrix for the given parameters
-
- """
- # init data
- Nini = len(a)
- Nfin = len(b)
-
-
- cpt = 0
-
- # we assume that no distances are null except those of the diagonal of distances
- u = np.ones(Nini)/Nini
- v = np.ones(Nfin)/Nfin
- uprev=np.zeros(Nini)
- vprev=np.zeros(Nini)
-
- #print reg
-
- K = np.exp(-M/reg)
- #print np.min(K)
-
- Kp = np.dot(np.diag(1/a),K)
- transp = K
- cpt = 0
- err=1
- while (err>stopThr and cpt<numItermax):
- if np.any(np.dot(K.T,u)==0) or np.any(np.isnan(u)) or np.any(np.isnan(v)):
- # we have reached the machine precision
- # come back to previous solution and quit loop
- print('Warning: numerical errrors')
- if cpt!=0:
- u = uprev
- v = vprev
- break
- uprev = u
- vprev = v
- v = np.divide(b,np.dot(K.T,u))
- u = 1./np.dot(Kp,v)
- if cpt%10==0:
- # we can speed up the process by checking for the error only all the 10th iterations
- transp = np.dot(np.diag(u),np.dot(K,np.diag(v)))
- err = np.linalg.norm((np.sum(transp,axis=0)-b))**2
- cpt = cpt +1
- #print 'err=',err,' cpt=',cpt
-
- return np.dot(np.diag(u),np.dot(K,np.diag(v)))
-
-
diff --git a/setup.py b/setup.py
index c7b0da0..c20c3f5 100755
--- a/setup.py
+++ b/setup.py
@@ -20,7 +20,7 @@ setup(name='python-webgen',
author=u'Remi Flamary',
author_email='remi.flamary@gmail.com',
url='https://github.com/rflamary/POT',
- packages=['ot','ot.emd','ot.bregman'],
+ packages=['ot','ot.emd'],
ext_modules = cythonize(Extension(
"ot.emd.emd", # the extesion name
sources=["ot/emd/emd.pyx", "ot/emd/EMD_wrap.cpp"], # the Cython source and