diff options
Diffstat (limited to 'ot/utils.py')
-rw-r--r-- | ot/utils.py | 9 |
1 files changed, 7 insertions, 2 deletions
diff --git a/ot/utils.py b/ot/utils.py index 8419c83..d4127e3 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -178,7 +178,9 @@ def cost_normalization(C, norm=None): The input cost matrix normalized according to given norm. """ - if norm == "median": + if norm is None: + pass + elif norm == "median": C /= float(np.median(C)) elif norm == "max": C /= float(np.max(C)) @@ -186,7 +188,10 @@ def cost_normalization(C, norm=None): C = np.log(1 + C) elif norm == "loglog": C = np.log1p(np.log1p(C)) - + else: + raise ValueError('Norm %s is not a valid option.\n' + 'Valid options are:\n' + 'median, max, log, loglog' % norm) return C |