summaryrefslogtreecommitdiff
path: root/ot/lp/__init__.py
diff options
context:
space:
mode:
authorRomain Tavenard <romain.tavenard@univ-rennes2.fr>2019-06-24 13:15:38 +0200
committerRomain Tavenard <romain.tavenard@univ-rennes2.fr>2019-06-24 13:15:38 +0200
commit0a039eb07a3ca9ae3c5635cca1719428f62bf67d (patch)
treee35061dc69f102e3e9f9f28a5d81c07666fa54d2 /ot/lp/__init__.py
parent77452dd92f607c3f18a6420cb8cd09fa5cd905a6 (diff)
Made weight vectors optional to match scipy's wass1d API
Diffstat (limited to 'ot/lp/__init__.py')
-rw-r--r--ot/lp/__init__.py29
1 files changed, 17 insertions, 12 deletions
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index 645ed8b..bf218d3 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -313,7 +313,7 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
return X
-def emd_1d(x_a, x_b, a, b, metric='sqeuclidean', dense=True, log=False):
+def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', dense=True, log=False):
"""Solves the Earth Movers distance problem between 1d measures and returns
the OT matrix
@@ -338,10 +338,10 @@ def emd_1d(x_a, x_b, a, b, metric='sqeuclidean', dense=True, log=False):
Source dirac locations (on the real line)
x_b : (nt,) or (ns, 1) ndarray, float64
Target dirac locations (on the real line)
- a : (ns,) ndarray, float64
- Source histogram (uniform weight if empty list)
- b : (nt,) ndarray, float64
- Target histogram (uniform weight if empty list)
+ a : (ns,) ndarray, float64, optional
+ Source histogram (default is uniform weight)
+ b : (nt,) ndarray, float64, optional
+ Target histogram (default is uniform weight)
metric: str, optional (default='sqeuclidean')
Metric to be used. Only strings listed in :func:`ot.dist` are accepted.
Due to implementation details, this function runs faster when
@@ -377,6 +377,9 @@ def emd_1d(x_a, x_b, a, b, metric='sqeuclidean', dense=True, log=False):
>>> ot.emd_1d(x_a, x_b, a, b)
array([[0. , 0.5],
[0.5, 0. ]])
+ >>> ot.emd_1d(x_a, x_b)
+ array([[0. , 0.5],
+ [0.5, 0. ]])
References
----------
@@ -401,9 +404,9 @@ def emd_1d(x_a, x_b, a, b, metric='sqeuclidean', dense=True, log=False):
"emd_1d should only be used with monodimensional data"
# if empty array given then use uniform distributions
- if len(a) == 0:
+ if a.ndim == 0 or len(a) == 0:
a = np.ones((x_a.shape[0],), dtype=np.float64) / x_a.shape[0]
- if len(b) == 0:
+ if b.ndim == 0 or len(b) == 0:
b = np.ones((x_b.shape[0],), dtype=np.float64) / x_b.shape[0]
x_a_1d = x_a.reshape((-1, ))
@@ -424,7 +427,7 @@ def emd_1d(x_a, x_b, a, b, metric='sqeuclidean', dense=True, log=False):
return G
-def emd2_1d(x_a, x_b, a, b, metric='sqeuclidean', dense=True, log=False):
+def emd2_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', dense=True, log=False):
"""Solves the Earth Movers distance problem between 1d measures and returns
the loss
@@ -449,10 +452,10 @@ def emd2_1d(x_a, x_b, a, b, metric='sqeuclidean', dense=True, log=False):
Source dirac locations (on the real line)
x_b : (nt,) or (ns, 1) ndarray, float64
Target dirac locations (on the real line)
- a : (ns,) ndarray, float64
- Source histogram (uniform weight if empty list)
- b : (nt,) ndarray, float64
- Target histogram (uniform weight if empty list)
+ a : (ns,) ndarray, float64, optional
+ Source histogram (default is uniform weight)
+ b : (nt,) ndarray, float64, optional
+ Target histogram (default is uniform weight)
metric: str, optional (default='sqeuclidean')
Metric to be used. Only strings listed in :func:`ot.dist` are accepted.
Due to implementation details, this function runs faster when
@@ -488,6 +491,8 @@ def emd2_1d(x_a, x_b, a, b, metric='sqeuclidean', dense=True, log=False):
>>> x_b = [0., 3.]
>>> ot.emd2_1d(x_a, x_b, a, b)
0.5
+ >>> ot.emd2_1d(x_a, x_b)
+ 0.5
References
----------