diff options
author | ievred <ievgen.redko@univ-st-etienne.fr> | 2020-04-01 09:00:03 +0200 |
---|---|---|
committer | ievred <ievgen.redko@univ-st-etienne.fr> | 2020-04-01 09:00:03 +0200 |
commit | 439860609df786a877383775dd901afe28480cc9 (patch) | |
tree | 70e94ca177eb5bd873bc6c5324ca5ddada2deba3 | |
parent | ba493aa5488507937b7f9707faa17128c9aa1872 (diff) |
fix imports remove checks
-rw-r--r-- | ot/da.py | 5 | ||||
-rw-r--r-- | test/test_da.py | 4 |
2 files changed, 5 insertions, 4 deletions
@@ -14,7 +14,7 @@ Domain adaptation with optimal transport import numpy as np import scipy.linalg as linalg -from .bregman import sinkhorn +from .bregman import sinkhorn, jcpot_barycenter from .lp import emd from .utils import unif, dist, kernel, cost_normalization from .utils import check_params, BaseEstimator @@ -1956,8 +1956,7 @@ class JCPOTTransport(BaseTransport): def __init__(self, reg_e=.1, max_iter=10, tol=10e-9, verbose=False, log=False, - metric="sqeuclidean", norm=None, - distribution_estimation=distribution_estimation_uniform, + metric="sqeuclidean", out_of_sample_map='ferradans'): self.reg_e = reg_e diff --git a/test/test_da.py b/test/test_da.py index 958df7b..7526f30 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -572,7 +572,6 @@ def test_jcpot_transport_class(): # test its computed otda.fit(Xs=Xs, ys=ys, Xt=Xt) - print(otda.proportions_) assert hasattr(otda, "coupling_") assert hasattr(otda, "proportions_") @@ -610,3 +609,6 @@ def test_jcpot_transport_class(): # check that the oos method is working assert_equal(transp_Xs_new.shape, Xs_new.shape) + + +test_jcpot_transport_class()
\ No newline at end of file |