summaryrefslogtreecommitdiff
path: root/ot/da.py
diff options
context:
space:
mode:
authoraje <leo_g_autheron@hotmail.fr>2017-08-29 15:38:11 +0200
committerNicolas Courty <Nico@MacBook-Pro-de-Nicolas.local>2017-09-01 11:09:13 +0200
commit6ae3ad7bb48b1fa8964cfd2791bdb86267776495 (patch)
treecb412b0e9a21b6a7c1d0622d84ce8550aaaee60f /ot/da.py
parent5a9795f08341458bd9e3befe0c2c6ea6fa891323 (diff)
Changes to LP solver:
- Allow to modify the maximal number of iterations - Display an error message in the python console if the solver encountered an issue
Diffstat (limited to 'ot/da.py')
-rw-r--r--ot/da.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/ot/da.py b/ot/da.py
index 78dc150..0dfd02f 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -658,7 +658,7 @@ class OTDA(object):
self.metric = metric
self.computed = False
- def fit(self, xs, xt, ws=None, wt=None, norm=None):
+ def fit(self, xs, xt, ws=None, wt=None, norm=None, numItermax=10000):
"""Fit domain adaptation between samples is xs and xt
(with optional weights)"""
self.xs = xs
@@ -674,7 +674,7 @@ class OTDA(object):
self.M = dist(xs, xt, metric=self.metric)
self.normalizeM(norm)
- self.G = emd(ws, wt, self.M)
+ self.G = emd(ws, wt, self.M, numItermax)
self.computed = True
def interp(self, direction=1):