summaryrefslogtreecommitdiff
path: root/ot
diff options
context:
space:
mode:
Diffstat (limited to 'ot')
-rw-r--r--ot/da.py2
-rw-r--r--ot/utils.py4
2 files changed, 4 insertions, 2 deletions
diff --git a/ot/da.py b/ot/da.py
index c1d9849..2af855d 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -1852,7 +1852,7 @@ class UnbalancedSinkhornTransport(BaseTransport):
"""
def __init__(self, reg_e=1., reg_m=0.1, method='sinkhorn',
- max_iter=10, tol=10e-9, verbose=False, log=False,
+ max_iter=10, tol=1e-9, verbose=False, log=False,
metric="sqeuclidean", norm=None,
distribution_estimation=distribution_estimation_uniform,
out_of_sample_map='ferradans', limit_max=10):
diff --git a/ot/utils.py b/ot/utils.py
index be839f8..a334fea 100644
--- a/ot/utils.py
+++ b/ot/utils.py
@@ -178,7 +178,9 @@ def cost_normalization(C, norm=None):
The input cost matrix normalized according to given norm.
"""
- if norm == "median":
+ if norm is None:
+ pass
+ elif norm == "median":
C /= float(np.median(C))
elif norm == "max":
C /= float(np.max(C))