summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2017-06-20 14:51:12 +0200
committerRémi Flamary <remi.flamary@gmail.com>2017-06-20 14:51:12 +0200
commit77bcf836272faf1a40fdea97131b85390e935be3 (patch)
tree4ab08952b9f67183f30d6a01af59c654075aaeb7
parent2bcc24aa05078cfbf160be06fc9ad166bee52904 (diff)
add clean zeros function for sparse distributions
-rw-r--r--ot/bregman.py2
-rw-r--r--ot/utils.py7
2 files changed, 9 insertions, 0 deletions
diff --git a/ot/bregman.py b/ot/bregman.py
index e847f24..a13345d 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -108,6 +108,8 @@ def sinkhorn(a,b, M, reg,method='sinkhorn', numItermax = 1000, stopThr=1e-9, ver
return sink()
+
+
def sinkhorn_knopp(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False, log=False,**kwargs):
"""
Solve the entropic regularization optimal transport problem and return the OT matrix
diff --git a/ot/utils.py b/ot/utils.py
index fc6b0d2..7ad7637 100644
--- a/ot/utils.py
+++ b/ot/utils.py
@@ -50,6 +50,13 @@ def unif(n):
"""
return np.ones((n,))/n
+def clean_zeros(a,b,M):
+ """ Remove all components with zeros weights in a and b
+ """
+ M2=M[a>0,:][:,b>0].copy() # copy force c style matrix (froemd)
+ a2=a[a>0]
+ b2=b[b>0]
+ return a2,b2,M2
def dist(x1,x2=None,metric='sqeuclidean'):
"""Compute distance between samples in x1 and x2 using function scipy.spatial.distance.cdist