From 0a039eb07a3ca9ae3c5635cca1719428f62bf67d Mon Sep 17 00:00:00 2001 From: Romain Tavenard Date: Mon, 24 Jun 2019 13:15:38 +0200 Subject: Made weight vectors optional to match scipy's wass1d API --- ot/lp/__init__.py | 29 +++++++++++++++++------------ 1 file changed, 17 insertions(+), 12 deletions(-) (limited to 'ot/lp/__init__.py') diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 645ed8b..bf218d3 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -313,7 +313,7 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None return X -def emd_1d(x_a, x_b, a, b, metric='sqeuclidean', dense=True, log=False): +def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', dense=True, log=False): """Solves the Earth Movers distance problem between 1d measures and returns the OT matrix @@ -338,10 +338,10 @@ def emd_1d(x_a, x_b, a, b, metric='sqeuclidean', dense=True, log=False): Source dirac locations (on the real line) x_b : (nt,) or (ns, 1) ndarray, float64 Target dirac locations (on the real line) - a : (ns,) ndarray, float64 - Source histogram (uniform weight if empty list) - b : (nt,) ndarray, float64 - Target histogram (uniform weight if empty list) + a : (ns,) ndarray, float64, optional + Source histogram (default is uniform weight) + b : (nt,) ndarray, float64, optional + Target histogram (default is uniform weight) metric: str, optional (default='sqeuclidean') Metric to be used. Only strings listed in :func:`ot.dist` are accepted. Due to implementation details, this function runs faster when @@ -375,6 +375,9 @@ def emd_1d(x_a, x_b, a, b, metric='sqeuclidean', dense=True, log=False): >>> x_a = [2., 0.] >>> x_b = [0., 3.] >>> ot.emd_1d(x_a, x_b, a, b) + array([[0. , 0.5], + [0.5, 0. ]]) + >>> ot.emd_1d(x_a, x_b) array([[0. , 0.5], [0.5, 0. ]]) @@ -401,9 +404,9 @@ def emd_1d(x_a, x_b, a, b, metric='sqeuclidean', dense=True, log=False): "emd_1d should only be used with monodimensional data" # if empty array given then use uniform distributions - if len(a) == 0: + if a.ndim == 0 or len(a) == 0: a = np.ones((x_a.shape[0],), dtype=np.float64) / x_a.shape[0] - if len(b) == 0: + if b.ndim == 0 or len(b) == 0: b = np.ones((x_b.shape[0],), dtype=np.float64) / x_b.shape[0] x_a_1d = x_a.reshape((-1, )) @@ -424,7 +427,7 @@ def emd_1d(x_a, x_b, a, b, metric='sqeuclidean', dense=True, log=False): return G -def emd2_1d(x_a, x_b, a, b, metric='sqeuclidean', dense=True, log=False): +def emd2_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', dense=True, log=False): """Solves the Earth Movers distance problem between 1d measures and returns the loss @@ -449,10 +452,10 @@ def emd2_1d(x_a, x_b, a, b, metric='sqeuclidean', dense=True, log=False): Source dirac locations (on the real line) x_b : (nt,) or (ns, 1) ndarray, float64 Target dirac locations (on the real line) - a : (ns,) ndarray, float64 - Source histogram (uniform weight if empty list) - b : (nt,) ndarray, float64 - Target histogram (uniform weight if empty list) + a : (ns,) ndarray, float64, optional + Source histogram (default is uniform weight) + b : (nt,) ndarray, float64, optional + Target histogram (default is uniform weight) metric: str, optional (default='sqeuclidean') Metric to be used. Only strings listed in :func:`ot.dist` are accepted. Due to implementation details, this function runs faster when @@ -488,6 +491,8 @@ def emd2_1d(x_a, x_b, a, b, metric='sqeuclidean', dense=True, log=False): >>> x_b = [0., 3.] >>> ot.emd2_1d(x_a, x_b, a, b) 0.5 + >>> ot.emd2_1d(x_a, x_b) + 0.5 References ---------- -- cgit v1.2.3