diff options
author | aje <leo_g_autheron@hotmail.fr> | 2017-08-30 10:53:31 +0200 |
---|---|---|
committer | aje <leo_g_autheron@hotmail.fr> | 2017-08-30 10:53:31 +0200 |
commit | 0316d552fa6005aaf0f6231eb9ca20441d5a2532 (patch) | |
tree | 5faf38ccb37e8476bd9b5afd807fa3de69d7d6c8 /ot/utils.py | |
parent | 5bbea9c608e66b499f9858af408cc65c07cf4ac2 (diff) |
Move normalize function in utils.py
Diffstat (limited to 'ot/utils.py')
-rw-r--r-- | ot/utils.py | 33 |
1 files changed, 33 insertions, 0 deletions
diff --git a/ot/utils.py b/ot/utils.py index 01f2a67..31a002b 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -134,6 +134,39 @@ def dist0(n, method='lin_square'): return res +def cost_normalization(C, norm=None): + """ Apply normalization to the loss matrix + + + Parameters + ---------- + C : np.array (n1, n2) + The cost matrix to normalize. + norm : str + type of normalization from 'median','max','log','loglog'. Any other + value do not normalize. + + + Returns + ------- + + C : np.array (n1, n2) + The input cost matrix normalized according to given norm. + + """ + + if norm == "median": + C /= float(np.median(C)) + elif norm == "max": + C /= float(np.max(C)) + elif norm == "log": + C = np.log(1 + C) + elif norm == "loglog": + C = np.log(1 + np.log(1 + C)) + + return C + + def dots(*args): """ dots function for multiple matrix multiply """ return reduce(np.dot, args) |