From d130b55fd5845bf0848bb02cebc58ce1ae89f8a3 Mon Sep 17 00:00:00 2001 From: RĂ©mi Flamary Date: Mon, 24 Oct 2016 15:51:31 +0200 Subject: bregman as module --- ot/__init__.py | 2 +- ot/bregman/__init__.py | 4 --- ot/bregman/sink.py | 91 -------------------------------------------------- 3 files changed, 1 insertion(+), 96 deletions(-) delete mode 100644 ot/bregman/__init__.py delete mode 100644 ot/bregman/sink.py (limited to 'ot') diff --git a/ot/__init__.py b/ot/__init__.py index 0a9e89b..f5490f7 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -14,4 +14,4 @@ from bregman import sinkhorn # utils functions from utils import dist,dots,unif -__all__ = ["emd","sinkhorn","utils",'datasets','plot','dist','dots'] +__all__ = ["emd","sinkhorn","utils",'datasets','bregman','plot','dist','dots'] diff --git a/ot/bregman/__init__.py b/ot/bregman/__init__.py deleted file mode 100644 index e4016ea..0000000 --- a/ot/bregman/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ - - -from .sink import sinkhorn - diff --git a/ot/bregman/sink.py b/ot/bregman/sink.py deleted file mode 100644 index 8b97e1e..0000000 --- a/ot/bregman/sink.py +++ /dev/null @@ -1,91 +0,0 @@ -# -*- coding: utf-8 -*- -""" -Created on Fri Oct 21 09:40:21 2016 - -@author: rflamary -""" - -import numpy as np - - -def sinkhorn(a,b, M, reg,numItermax = 1000,stopThr=1e-9): - """ - Solve the optimal transport problem (OT) - - .. math:: - \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) - - s.t. \gamma 1 = a - - \gamma^T 1= b - - \gamma\geq 0 - where : - - - M is the metric cost matrix - - Omega is the entropic regularization term - - a and b are the sample weights - - Parameters - ---------- - a : (ns,) ndarray - samples in the source domain - b : (nt,) ndarray - samples in the target domain - M : (ns,nt) ndarray - loss matrix - reg: float() - Regularization term >0 - - - Returns - ------- - gamma: (ns x nt) ndarray - Optimal transportation matrix for the given parameters - - """ - # init data - Nini = len(a) - Nfin = len(b) - - - cpt = 0 - - # we assume that no distances are null except those of the diagonal of distances - u = np.ones(Nini)/Nini - v = np.ones(Nfin)/Nfin - uprev=np.zeros(Nini) - vprev=np.zeros(Nini) - - #print reg - - K = np.exp(-M/reg) - #print np.min(K) - - Kp = np.dot(np.diag(1/a),K) - transp = K - cpt = 0 - err=1 - while (err>stopThr and cpt