diff options
Diffstat (limited to 'ot')
-rw-r--r-- | ot/lp/__init__.py | 114 |
1 files changed, 92 insertions, 22 deletions
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index e9635a1..a350d60 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -321,8 +321,8 @@ def emd_1d(x_a, x_b, a, b, metric='sqeuclidean', dense=True, log=False): .. math:: \gamma = arg\min_\gamma \sum_i \sum_j \gamma_{ij} d(x_a[i], x_b[j]) - s.t. \gamma 1 = a - \gamma^T 1= b + s.t. \gamma 1 = a, + \gamma^T 1= b, \gamma\geq 0 where : @@ -330,28 +330,27 @@ def emd_1d(x_a, x_b, a, b, metric='sqeuclidean', dense=True, log=False): - x_a and x_b are the samples - a and b are the sample weights - Uses the algorithm proposed in [1]_ + Uses the algorithm detailed in [1]_ Parameters ---------- x_a : (ns,) or (ns, 1) ndarray, float64 - Source histogram (uniform weight if empty list) + Source dirac locations (on the real line) x_b : (nt,) or (nt, 1) ndarray, float64 - Target histogram (uniform weight if empty list) + Target dirac locations (on the real line) a : (ns,) ndarray, float64 Source histogram (uniform weight if empty list) b : (nt,) ndarray, float64 Target histogram (uniform weight if empty list) + metric: str, optional (default='sqeuclidean') + Metric to be used. Only strings listed in ... are accepted. + Due to implementation details, this function runs faster when + `'sqeuclidean'`, `'cityblock'`, or `'euclidean'` metrics are used. dense: boolean, optional (default=True) If True, returns :math:`\gamma` as a dense ndarray of shape (ns, nt). Otherwise returns a sparse representation using scipy's `coo_matrix` - format. - Due to implementation details, this function runs faster when + format. Due to implementation details, this function runs faster when dense is set to False. - metric: str, optional (default='sqeuclidean') - Metric to be used. Has to be a string. - Due to implementation details, this function runs faster when - `'sqeuclidean'` or `'euclidean'` metrics are used. log: boolean, optional (default=False) If True, returns a dictionary containing the cost. 
Otherwise returns only the optimal transportation matrix. @@ -368,28 +367,28 @@ def emd_1d(x_a, x_b, a, b, metric='sqeuclidean', dense=True, log=False): -------- Simple example with obvious solution. The function emd_1d accepts lists and - perform automatic conversion to numpy arrays + performs automatic conversion to numpy arrays >>> import ot >>> a=[.5, .5] >>> b=[.5, .5] - >>> x_a = [0., 2.] + >>> x_a = [2., 0.] >>> x_b = [0., 3.] - >>> ot.emd_1d(a, b, x_a, x_b) - array([[ 0.5, 0. ], - [ 0. , 0.5]]) + >>> ot.emd_1d(x_a, x_b, a, b) + array([[0. , 0.5], + [0.5, 0. ]]) References ---------- - .. [1] TODO + .. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal + Transport", 2018. See Also -------- ot.lp.emd : EMD for multidimensional distributions ot.lp.emd2_1d : EMD for 1d distributions (returns cost instead of the transportation matrix) - """ a = np.asarray(a, dtype=np.float64) b = np.asarray(b, dtype=np.float64) @@ -418,10 +417,9 @@ def emd_1d(x_a, x_b, a, b, metric='sqeuclidean', dense=True, log=False): G = coo_matrix((G_sorted, (perm_a[indices[:, 0]], perm_b[indices[:, 1]])), shape=(a.shape[0], b.shape[0])) if dense: - G = G.todense() + G = G.toarray() if log: - log = {} - log['cost'] = cost + log = {'cost': cost} return G, log return G @@ -430,10 +428,82 @@ def emd2_1d(x_a, x_b, a, b, metric='sqeuclidean', dense=True, log=False): """Solves the Earth Movers distance problem between 1d measures and returns the loss + + .. math:: + \gamma = arg\min_\gamma \sum_i \sum_j \gamma_{ij} d(x_a[i], x_b[j]) + + s.t. 
\gamma 1 = a, + \gamma^T 1= b, + \gamma\geq 0 + where : + + - d is the metric + - x_a and x_b are the samples + - a and b are the sample weights + + Uses the algorithm detailed in [1]_ + + Parameters + ---------- + x_a : (ns,) or (ns, 1) ndarray, float64 + Source dirac locations (on the real line) + x_b : (nt,) or (nt, 1) ndarray, float64 + Target dirac locations (on the real line) + a : (ns,) ndarray, float64 + Source histogram (uniform weight if empty list) + b : (nt,) ndarray, float64 + Target histogram (uniform weight if empty list) + metric: str, optional (default='sqeuclidean') + Metric to be used. Only strings listed in ... are accepted. + Due to implementation details, this function runs faster when + `'sqeuclidean'`, `'cityblock'`, or `'euclidean'` metrics are used. + dense: boolean, optional (default=True) + If True, returns :math:`\gamma` as a dense ndarray of shape (ns, nt). + Otherwise returns a sparse representation using scipy's `coo_matrix` + format. Only used if log is set to True. Due to implementation details, + this function runs faster when dense is set to False. + log: boolean, optional (default=False) + If True, returns a dictionary containing the transportation matrix. + Otherwise returns only the loss. + + Returns + ------- + loss: float + Cost associated to the optimal transportation + log: dict + If input log is True, a dictionary containing the Optimal transportation + matrix for the given parameters + + + Examples + -------- + + Simple example with obvious solution. The function emd2_1d accepts lists and + performs automatic conversion to numpy arrays + + >>> import ot + >>> a=[.5, .5] + >>> b=[.5, .5] + >>> x_a = [2., 0.] + >>> x_b = [0., 3.] + >>> ot.emd2_1d(x_a, x_b, a, b) + 0.5 + + References + ---------- + + .. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal + Transport", 2018. 
+ + See Also + -------- + ot.lp.emd2 : EMD for multidimensional distributions + ot.lp.emd_1d : EMD for 1d distributions (returns the transportation matrix + instead of the cost) """ # If we do not return G (log==False), then we should not cast it to dense # (useless overhead) - G, log_emd = emd_1d(a=a, b=b, x_a=x_a, x_b=x_b, metric=metric, + G, log_emd = emd_1d(x_a=x_a, x_b=x_b, a=a, b=b, metric=metric, dense=dense and log, log=True) cost = log_emd['cost'] if log: |