summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorngayraud <nat.gayraud@gmail.com>2019-08-12 16:37:58 -0400
committerngayraud <nat.gayraud@gmail.com>2019-08-12 16:37:58 -0400
commit9d4b786a036ac95989825beec819521089fb4feb (patch)
tree3d9fcc4fd26e5d8dbe100d79eddf0776801df33a
parent092866815cf906012f9194b87af1e7ae0270f7e7 (diff)
fixes for travis, added test, minor nits
-rw-r--r--.travis.yml5
-rw-r--r--ot/da.py2
-rw-r--r--ot/utils.py4
-rw-r--r--test/test_da.py73
4 files changed, 80 insertions, 4 deletions
diff --git a/.travis.yml b/.travis.yml
index 5e5694b..72fd29a 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -13,7 +13,7 @@ matrix:
python: 3.5
- os: linux
sudo: required
- python: 3.6
+ python: 3.6
- os: linux
sudo: required
python: 2.7
@@ -21,7 +21,6 @@ before_install:
- ./.travis/before_install.sh
before_script: # configure a headless display to test plot generation
- "export DISPLAY=:99.0"
- - "sh -e /etc/init.d/xvfb start"
- sleep 3 # give xvfb some time to start
# command to install dependencies
install:
@@ -30,6 +29,8 @@ install:
- pip install flake8 pytest "pytest-cov<2.6"
- pip install .
# command to run tests + check syntax style
+services:
+ - xvfb
script:
- python setup.py develop
- flake8 examples/ ot/ test/
diff --git a/ot/da.py b/ot/da.py
index c1d9849..2af855d 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -1852,7 +1852,7 @@ class UnbalancedSinkhornTransport(BaseTransport):
"""
def __init__(self, reg_e=1., reg_m=0.1, method='sinkhorn',
- max_iter=10, tol=10e-9, verbose=False, log=False,
+ max_iter=10, tol=1e-9, verbose=False, log=False,
metric="sqeuclidean", norm=None,
distribution_estimation=distribution_estimation_uniform,
out_of_sample_map='ferradans', limit_max=10):
diff --git a/ot/utils.py b/ot/utils.py
index be839f8..a334fea 100644
--- a/ot/utils.py
+++ b/ot/utils.py
@@ -178,7 +178,9 @@ def cost_normalization(C, norm=None):
The input cost matrix normalized according to given norm.
"""
- if norm == "median":
+ if norm is None:
+ pass
+ elif norm == "median":
C /= float(np.median(C))
elif norm == "max":
C /= float(np.max(C))
diff --git a/test/test_da.py b/test/test_da.py
index f7f3a9d..9efd2d9 100644
--- a/test/test_da.py
+++ b/test/test_da.py
@@ -245,6 +245,79 @@ def test_sinkhorn_transport_class():
assert len(otda.log_.keys()) != 0
+def test_unbalanced_sinkhorn_transport_class():
+ """test_sinkhorn_transport
+ """
+
+ ns = 150
+ nt = 200
+
+ Xs, ys = make_data_classif('3gauss', ns)
+ Xt, yt = make_data_classif('3gauss2', nt)
+
+ otda = ot.da.UnbalancedSinkhornTransport()
+
+ # test its computed
+ otda.fit(Xs=Xs, Xt=Xt)
+ assert hasattr(otda, "cost_")
+ assert hasattr(otda, "coupling_")
+ assert hasattr(otda, "log_")
+
+ # test dimensions of coupling
+ assert_equal(otda.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
+ assert_equal(otda.coupling_.shape, ((Xs.shape[0], Xt.shape[0])))
+
+ # test margin constraints
+ mu_s = unif(ns)
+ mu_t = unif(nt)
+ assert_allclose(
+ np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
+ assert_allclose(
+ np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
+
+ # test transform
+ transp_Xs = otda.transform(Xs=Xs)
+ assert_equal(transp_Xs.shape, Xs.shape)
+
+ Xs_new, _ = make_data_classif('3gauss', ns + 1)
+ transp_Xs_new = otda.transform(Xs_new)
+
+ # check that the oos method is working
+ assert_equal(transp_Xs_new.shape, Xs_new.shape)
+
+ # test inverse transform
+ transp_Xt = otda.inverse_transform(Xt=Xt)
+ assert_equal(transp_Xt.shape, Xt.shape)
+
+ Xt_new, _ = make_data_classif('3gauss2', nt + 1)
+ transp_Xt_new = otda.inverse_transform(Xt=Xt_new)
+
+ # check that the oos method is working
+ assert_equal(transp_Xt_new.shape, Xt_new.shape)
+
+ # test fit_transform
+ transp_Xs = otda.fit_transform(Xs=Xs, Xt=Xt)
+ assert_equal(transp_Xs.shape, Xs.shape)
+
+ # test unsupervised vs semi-supervised mode
+ otda_unsup = ot.da.SinkhornTransport()
+ otda_unsup.fit(Xs=Xs, Xt=Xt)
+ n_unsup = np.sum(otda_unsup.cost_)
+
+ otda_semi = ot.da.SinkhornTransport()
+ otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
+ assert_equal(otda_semi.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
+ n_semisup = np.sum(otda_semi.cost_)
+
+ # check that the cost matrix norms are indeed different
+ assert n_unsup != n_semisup, "semisupervised mode not working"
+
+ # check everything runs well with log=True
+ otda = ot.da.SinkhornTransport(log=True)
+ otda.fit(Xs=Xs, ys=ys, Xt=Xt)
+ assert len(otda.log_.keys()) != 0
+
+
def test_emd_transport_class():
"""test_sinkhorn_transport
"""