summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorievred <ievgen.redko@univ-st-etienne.fr>2020-04-01 09:00:03 +0200
committerievred <ievgen.redko@univ-st-etienne.fr>2020-04-01 09:00:03 +0200
commit439860609df786a877383775dd901afe28480cc9 (patch)
tree70e94ca177eb5bd873bc6c5324ca5ddada2deba3
parentba493aa5488507937b7f9707faa17128c9aa1872 (diff)
fix imports remove checks
-rw-r--r--ot/da.py5
-rw-r--r--test/test_da.py4
2 files changed, 5 insertions, 4 deletions
diff --git a/ot/da.py b/ot/da.py
index a9c3cea..e62e495 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -14,7 +14,7 @@ Domain adaptation with optimal transport
import numpy as np
import scipy.linalg as linalg
-from .bregman import sinkhorn
+from .bregman import sinkhorn, jcpot_barycenter
from .lp import emd
from .utils import unif, dist, kernel, cost_normalization
from .utils import check_params, BaseEstimator
@@ -1956,8 +1956,7 @@ class JCPOTTransport(BaseTransport):
def __init__(self, reg_e=.1, max_iter=10,
tol=10e-9, verbose=False, log=False,
- metric="sqeuclidean", norm=None,
- distribution_estimation=distribution_estimation_uniform,
+ metric="sqeuclidean",
out_of_sample_map='ferradans'):
self.reg_e = reg_e
diff --git a/test/test_da.py b/test/test_da.py
index 958df7b..7526f30 100644
--- a/test/test_da.py
+++ b/test/test_da.py
@@ -572,7 +572,6 @@ def test_jcpot_transport_class():
# test its computed
otda.fit(Xs=Xs, ys=ys, Xt=Xt)
- print(otda.proportions_)
assert hasattr(otda, "coupling_")
assert hasattr(otda, "proportions_")
@@ -610,3 +609,6 @@ def test_jcpot_transport_class():
# check that the oos method is working
assert_equal(transp_Xs_new.shape, Xs_new.shape)
+
+
+test_jcpot_transport_class() \ No newline at end of file