diff options
author | RĂ©mi Flamary <remi.flamary@gmail.com> | 2021-10-27 08:41:08 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-10-27 08:41:08 +0200 |
commit | d7554331fc409fea48ee758fd630909dd9dc4827 (patch) | |
tree | 9b8ed4bf94c12d034d5fb1de5b7b5b76c23b4d05 /ot/utils.py | |
parent | 76450dddf8dd62b9714b72e99ae075516246d433 (diff) |
[WIP] Sinkhorn in log space (#290)
* adda sinkhorn log and working sinkhorn2 function
* more tests pass
* more tests pass
* it works but not by default yet
* remove warningd
* update circleci doc
* update circleci doc
* new sinkhorn implemeted but not by default
* better
* doctest pass
* test doctest
* new test utils
* remove pep8 errors
* remove pep8 errors
* doc new implementtaion with log
* test sinkhorn 2
* doc for log implementation
Diffstat (limited to 'ot/utils.py')
-rw-r--r-- | ot/utils.py | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/ot/utils.py b/ot/utils.py index 6a782e6..0608aee 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -183,7 +183,7 @@ def euclidean_distances(X, Y, squared=False): return c -def dist(x1, x2=None, metric='sqeuclidean'): +def dist(x1, x2=None, metric='sqeuclidean', p=2): """Compute distance between samples in x1 and x2 .. note:: This function is backend-compatible and will work on arrays @@ -222,7 +222,7 @@ def dist(x1, x2=None, metric='sqeuclidean'): if not get_backend(x1, x2).__name__ == 'numpy': raise NotImplementedError() else: - return cdist(x1, x2, metric=metric) + return cdist(x1, x2, metric=metric, p=p) def dist0(n, method='lin_square'): |