diff options
author | ngayraud <nat.gayraud@gmail.com> | 2019-08-12 16:37:58 -0400 |
---|---|---|
committer | ngayraud <nat.gayraud@gmail.com> | 2019-08-12 16:37:58 -0400 |
commit | 9d4b786a036ac95989825beec819521089fb4feb (patch) | |
tree | 3d9fcc4fd26e5d8dbe100d79eddf0776801df33a /ot | |
parent | 092866815cf906012f9194b87af1e7ae0270f7e7 (diff) |
fixes for travis, added test, minor nits
Diffstat (limited to 'ot')
-rw-r--r-- | ot/da.py | 2 | ||||
-rw-r--r-- | ot/utils.py | 4 |
2 files changed, 4 insertions, 2 deletions
@@ -1852,7 +1852,7 @@ class UnbalancedSinkhornTransport(BaseTransport): """ def __init__(self, reg_e=1., reg_m=0.1, method='sinkhorn', - max_iter=10, tol=10e-9, verbose=False, log=False, + max_iter=10, tol=1e-9, verbose=False, log=False, metric="sqeuclidean", norm=None, distribution_estimation=distribution_estimation_uniform, out_of_sample_map='ferradans', limit_max=10): diff --git a/ot/utils.py b/ot/utils.py index be839f8..a334fea 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -178,7 +178,9 @@ def cost_normalization(C, norm=None): The input cost matrix normalized according to given norm. """ - if norm == "median": + if norm is None: + pass + elif norm == "median": C /= float(np.median(C)) elif norm == "max": C /= float(np.max(C)) |