summaryrefslogtreecommitdiff
path: root/ot/utils.py
diff options
context:
space:
mode:
authoraje <leo_g_autheron@hotmail.fr>2017-08-30 10:53:31 +0200
committeraje <leo_g_autheron@hotmail.fr>2017-08-30 10:53:31 +0200
commit0316d552fa6005aaf0f6231eb9ca20441d5a2532 (patch)
tree5faf38ccb37e8476bd9b5afd807fa3de69d7d6c8 /ot/utils.py
parent5bbea9c608e66b499f9858af408cc65c07cf4ac2 (diff)
Move normalize function in utils.py
Diffstat (limited to 'ot/utils.py')
-rw-r--r--ot/utils.py33
1 files changed, 33 insertions, 0 deletions
diff --git a/ot/utils.py b/ot/utils.py
index 01f2a67..31a002b 100644
--- a/ot/utils.py
+++ b/ot/utils.py
@@ -134,6 +134,39 @@ def dist0(n, method='lin_square'):
return res
+def cost_normalization(C, norm=None):
+ """ Apply normalization to the loss matrix
+
+
+ Parameters
+ ----------
+ C : np.array (n1, n2)
+ The cost matrix to normalize.
+ norm : str
+ type of normalization from 'median','max','log','loglog'. Any other
+ value do not normalize.
+
+
+ Returns
+ -------
+
+ C : np.array (n1, n2)
+ The input cost matrix normalized according to given norm.
+
+ """
+
+ if norm == "median":
+ C /= float(np.median(C))
+ elif norm == "max":
+ C /= float(np.max(C))
+ elif norm == "log":
+ C = np.log(1 + C)
+ elif norm == "loglog":
+ C = np.log(1 + np.log(1 + C))
+
+ return C
+
+
def dots(*args):
""" dots function for multiple matrix multiply """
return reduce(np.dot, args)