From 0316d552fa6005aaf0f6231eb9ca20441d5a2532 Mon Sep 17 00:00:00 2001 From: aje Date: Wed, 30 Aug 2017 10:53:31 +0200 Subject: Move normalize function in utils.py --- ot/utils.py | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) (limited to 'ot/utils.py') diff --git a/ot/utils.py b/ot/utils.py index 01f2a67..31a002b 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -134,6 +134,39 @@ def dist0(n, method='lin_square'): return res +def cost_normalization(C, norm=None): + """ Apply normalization to the loss matrix + + + Parameters + ---------- + C : np.array (n1, n2) + The cost matrix to normalize. + norm : str + type of normalization from 'median','max','log','loglog'. Any other + value do not normalize. + + + Returns + ------- + + C : np.array (n1, n2) + The input cost matrix normalized according to given norm. + + """ + + if norm == "median": + C /= float(np.median(C)) + elif norm == "max": + C /= float(np.max(C)) + elif norm == "log": + C = np.log(1 + C) + elif norm == "loglog": + C = np.log(1 + np.log(1 + C)) + + return C + + def dots(*args): """ dots function for multiple matrix multiply """ return reduce(np.dot, args) -- cgit v1.2.3