From 1a4c264cc9b2cb0bb89840ee9175177e86eef3ef Mon Sep 17 00:00:00 2001 From: ievred Date: Wed, 8 Apr 2020 16:34:39 +0200 Subject: added label normalization to utils --- ot/utils.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) (limited to 'ot/utils.py') diff --git a/ot/utils.py b/ot/utils.py index b71458b..c154f99 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -200,6 +200,28 @@ def dots(*args): return reduce(np.dot, args) +def label_normalization(y, start=0): + """ Transform labels to start at a given value + + Parameters + ---------- + y : array-like, shape (n, ) + The vector of labels to be normalized. + start : int + Desired value for the smallest label in y (default=0) + + Returns + ------- + y : array-like, shape (n1, ) + The input vector of labels normalized according to given start value. + """ + + diff = np.min(np.unique(y)) - start + if diff != 0: + y -= diff + return y + + def fun(f, q_in, q_out): """ Utility function for parmap with no serializing problems """ while True: -- cgit v1.2.3