summaryrefslogtreecommitdiff
path: root/ot/da.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2017-07-11 12:05:07 +0200
committerRémi Flamary <remi.flamary@gmail.com>2017-07-11 12:05:07 +0200
commit4efdda7853ab7c0eab17b947e28e416f2b16dc51 (patch)
treef3bb54cec4c32d6b57d6f7137648bd9af4730bd6 /ot/da.py
parent57330c5bd7dca7f315cec4c4f692737cae580ec6 (diff)
add documentation
Diffstat (limited to 'ot/da.py')
-rw-r--r--ot/da.py13
1 files changed, 10 insertions, 3 deletions
diff --git a/ot/da.py b/ot/da.py
index 557e2aa..ddf1c60 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -670,10 +670,16 @@ class OTDA(object):
return xf[idx,:]+x-x0[idx,:] # aply the delta to the interpolation
def normalizeM(self, norm):
+ """ Apply normalization to the loss matrix
+
+
+ Parameters
+ ----------
+ norm : str
+ type of normalization from 'median','max','log','loglog'
+
"""
- It may help to normalize the cost matrix self.M if there are numerical
- errors during the sinkhorn based algorithms.
- """
+
if norm == "median":
self.M /= float(np.median(self.M))
elif norm == "max":
@@ -682,6 +688,7 @@ class OTDA(object):
self.M = np.log(1 + self.M)
elif norm == "loglog":
self.M = np.log(1 + np.log(1 + self.M))
+
class OTDA_sinkhorn(OTDA):