diff options
-rw-r--r-- | ot/lp/__init__.py | 77 | ||||
-rw-r--r-- | test/test_ot.py | 8 |
2 files changed, 79 insertions, 6 deletions
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index decff29..e9635a1 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -313,10 +313,83 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None return X -def emd_1d(a, b, x_a, x_b, metric='sqeuclidean', dense=True, log=False): +def emd_1d(x_a, x_b, a, b, metric='sqeuclidean', dense=True, log=False): """Solves the Earth Movers distance problem between 1d measures and returns the OT matrix + + .. math:: + \gamma = arg\min_\gamma \sum_i \sum_j \gamma_{ij} d(x_a[i], x_b[j]) + + s.t. \gamma 1 = a + \gamma^T 1= b + \gamma\geq 0 + where : + + - d is the metric + - x_a and x_b are the samples + - a and b are the sample weights + + Uses the algorithm proposed in [1]_ + + Parameters + ---------- + x_a : (ns,) or (ns, 1) ndarray, float64 + Source dirac locations (on the real line) + x_b : (nt,) or (nt, 1) ndarray, float64 + Target dirac locations (on the real line) + a : (ns,) ndarray, float64 + Source histogram (uniform weight if empty list) + b : (nt,) ndarray, float64 + Target histogram (uniform weight if empty list) + dense: boolean, optional (default=True) + If True, returns :math:`\gamma` as a dense ndarray of shape (ns, nt). + Otherwise returns a sparse representation using scipy's `coo_matrix` + format. + Due to implementation details, this function runs faster when + dense is set to False. + metric: str, optional (default='sqeuclidean') + Metric to be used. Has to be a string. + Due to implementation details, this function runs faster when + `'sqeuclidean'` or `'euclidean'` metrics are used. + log: boolean, optional (default=False) + If True, returns a dictionary containing the cost. + Otherwise returns only the optimal transportation matrix. 
+ + Returns + ------- + gamma: (ns, nt) ndarray + Optimal transportation matrix for the given parameters + log: dict + If input log is True, a dictionary containing the cost + + + Examples + -------- + + Simple example with obvious solution. The function emd_1d accepts lists and + performs automatic conversion to numpy arrays + + >>> import ot + >>> a=[.5, .5] + >>> b=[.5, .5] + >>> x_a = [0., 2.] + >>> x_b = [0., 3.] + >>> ot.emd_1d(x_a, x_b, a, b) + array([[ 0.5, 0. ], + [ 0. , 0.5]]) + + References + ---------- + + .. [1] TODO + + See Also + -------- + ot.lp.emd : EMD for multidimensional distributions + ot.lp.emd2_1d : EMD for 1d distributions (returns cost instead of the + transportation matrix) + """ a = np.asarray(a, dtype=np.float64) b = np.asarray(b, dtype=np.float64) @@ -353,7 +426,7 @@ def emd_1d(a, b, x_a, x_b, metric='sqeuclidean', dense=True, log=False): return G -def emd2_1d(a, b, x_a, x_b, metric='sqeuclidean', dense=True, log=False): +def emd2_1d(x_a, x_b, a, b, metric='sqeuclidean', dense=True, log=False): """Solves the Earth Movers distance problem between 1d measures and returns the loss diff --git a/test/test_ot.py b/test/test_ot.py index 2a2e0a5..6d6ea26 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -59,10 +59,10 @@ def test_emd_1d_emd2_1d(): G, log = ot.emd([], [], M, log=True) wass = log["cost"] - G_1d, log = ot.emd_1d([], [], u, v, metric='sqeuclidean', log=True) + G_1d, log = ot.emd_1d(u, v, [], [], metric='sqeuclidean', log=True) wass1d = log["cost"] - wass1d_emd2 = ot.emd2_1d([], [], u, v, metric='sqeuclidean', log=False) - wass1d_euc = ot.emd2_1d([], [], u, v, metric='euclidean', log=False) + wass1d_emd2 = ot.emd2_1d(u, v, [], [], metric='sqeuclidean', log=False) + wass1d_euc = ot.emd2_1d(u, v, [], [], metric='euclidean', log=False) # check loss is similar np.testing.assert_allclose(wass, wass1d) @@ -82,7 +82,7 @@ def test_emd_1d_emd2_1d(): # check AssertionError is raised if called on non 1d arrays u = np.random.randn(n, 2) v = 
np.random.randn(m, 2) - np.testing.assert_raises(AssertionError, ot.emd_1d, [], [], u, v) + np.testing.assert_raises(AssertionError, ot.emd_1d, u, v, [], []) def test_emd_empty(): |