summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r-- ot/da.py 56
-rw-r--r-- test/test_da.py 51
2 files changed, 72 insertions, 35 deletions
diff --git a/ot/da.py b/ot/da.py
index d30c821..6b98a17 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -15,6 +15,7 @@ from .lp import emd
from .utils import unif, dist, kernel
from .optim import cg
from .optim import gcg
+import warnings
def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
@@ -921,15 +922,8 @@ class OTDA_mapping_kernel(OTDA_mapping_linear):
# proposal
##############################################################################
-# from sklearn.base import BaseEstimator
-# from sklearn.metrics import pairwise_distances
-
-##############################################################################
-# adapted from scikit-learn
-
-import warnings
-# from .externals.six import string_types, iteritems
+# adapted from sklearn
class BaseEstimator(object):
"""Base class for all estimators in scikit-learn
@@ -1067,7 +1061,7 @@ def distribution_estimation_uniform(X):
The uniform distribution estimated from X
"""
- return np.ones(X.shape[0]) / float(X.shape[0])
+ return unif(X.shape[0])
class BaseTransport(BaseEstimator):
@@ -1092,29 +1086,20 @@ class BaseTransport(BaseEstimator):
"""
# pairwise distance
- Cost = dist(Xs, Xt, metric=self.metric)
+ self.Cost = dist(Xs, Xt, metric=self.metric)
if self.mode == "semisupervised":
print("TODO: modify cost matrix accordingly")
pass
# distribution estimation
- mu_s = self.distribution_estimation(Xs)
- mu_t = self.distribution_estimation(Xt)
+ self.mu_s = self.distribution_estimation(Xs)
+ self.mu_t = self.distribution_estimation(Xt)
# store arrays of samples
self.Xs = Xs
self.Xt = Xt
- # coupling estimation
- if self.method == "sinkhorn":
- self.gamma_ = sinkhorn(
- a=mu_s, b=mu_t, M=Cost, reg=self.reg_e,
- numItermax=self.max_iter, stopThr=self.tol,
- verbose=self.verbose, log=self.log)
- else:
- print("TODO: implement the other methods")
-
return self
def fit_transform(self, Xs=None, ys=None, Xt=None, yt=None):
@@ -1157,8 +1142,7 @@ class BaseTransport(BaseEstimator):
The transport source samples.
"""
- # TODO: check whether Xs is new or not
- if self.Xs == Xs:
+ if np.array_equal(self.Xs, Xs):
# perform standard barycentric mapping
transp = self.gamma_ / np.sum(self.gamma_, 1)[:, None]
@@ -1169,7 +1153,9 @@ class BaseTransport(BaseEstimator):
transp_Xs = np.dot(transp, self.Xt)
else:
# perform out of sample mapping
- print("out of sample mapping not yet implemented")
+ print("Warning: out of sample mapping not yet implemented")
+ print("input data will be returned")
+ transp_Xs = Xs
return transp_Xs
@@ -1191,8 +1177,7 @@ class BaseTransport(BaseEstimator):
The transported target samples.
"""
- # TODO: check whether Xt is new or not
- if self.Xt == Xt:
+ if np.array_equal(self.Xt, Xt):
# perform standard barycentric mapping
transp_ = self.gamma_.T / np.sum(self.gamma_, 0)[:, None]
@@ -1203,7 +1188,9 @@ class BaseTransport(BaseEstimator):
transp_Xt = np.dot(transp_, self.Xs)
else:
# perform out of sample mapping
- print("out of sample mapping not yet implemented")
+ print("Warning: out of sample mapping not yet implemented")
+ print("input data will be returned")
+ transp_Xt = Xt
return transp_Xt
@@ -1254,7 +1241,7 @@ class SinkhornTransport(BaseTransport):
"""
def __init__(self, reg_e=1., mode="unsupervised", max_iter=1000,
- tol=10e-9, verbose=False, log=False, mapping="barycentric",
+ tol=10e-9, verbose=False, log=False,
metric="sqeuclidean",
distribution_estimation=distribution_estimation_uniform,
out_of_sample_map='ferradans'):
@@ -1265,7 +1252,6 @@ class SinkhornTransport(BaseTransport):
self.tol = tol
self.verbose = verbose
self.log = log
- self.mapping = mapping
self.metric = metric
self.distribution_estimation = distribution_estimation
self.method = "sinkhorn"
@@ -1290,10 +1276,10 @@ class SinkhornTransport(BaseTransport):
Returns self.
"""
- return super(SinkhornTransport, self).fit(Xs, ys, Xt, yt)
-
+ self = super(SinkhornTransport, self).fit(Xs, ys, Xt, yt)
-if __name__ == "__main__":
- print("Small test")
-
- st = SinkhornTransport()
+ # coupling estimation
+ self.gamma_ = sinkhorn(
+ a=self.mu_s, b=self.mu_t, M=self.Cost, reg=self.reg_e,
+ numItermax=self.max_iter, stopThr=self.tol,
+ verbose=self.verbose, log=self.log)
diff --git a/test/test_da.py b/test/test_da.py
index dfba83f..e7b4ed1 100644
--- a/test/test_da.py
+++ b/test/test_da.py
@@ -6,6 +6,57 @@
import numpy as np
import ot
+from numpy.testing.utils import assert_allclose, assert_equal
+from ot.datasets import get_data_classif
+from ot.utils import unif
+
+np.random.seed(42)
+
+
+def test_sinkhorn_transport():
+ """test_sinkhorn_transport
+ """
+
+ ns = 150
+ nt = 200
+
+ Xs, ys = get_data_classif('3gauss', ns)
+ Xt, yt = get_data_classif('3gauss2', nt)
+
+ clf = ot.da.SinkhornTransport()
+
+ # test that the coupling is computed
+ clf.fit(Xs=Xs, Xt=Xt)
+
+ # test dimensions of coupling
+ assert_equal(clf.Cost.shape, ((Xs.shape[0], Xt.shape[0])))
+ assert_equal(clf.gamma_.shape, ((Xs.shape[0], Xt.shape[0])))
+
+ # test margin constraints
+ mu_s = unif(ns)
+ mu_t = unif(nt)
+ assert_allclose(np.sum(clf.gamma_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
+ assert_allclose(np.sum(clf.gamma_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
+
+ # test transform
+ transp_Xs = clf.transform(Xs=Xs)
+ assert_equal(transp_Xs.shape, Xs.shape)
+
+ Xs_new, _ = get_data_classif('3gauss', ns + 1)
+ transp_Xs_new = clf.transform(Xs_new)
+
+ # check that out-of-sample mapping is not implemented and returns the input data
+ assert_equal(transp_Xs_new, Xs_new)
+
+ # test inverse transform
+ transp_Xt = clf.inverse_transform(Xt=Xt)
+ assert_equal(transp_Xt.shape, Xt.shape)
+
+ Xt_new, _ = get_data_classif('3gauss2', nt + 1)
+ transp_Xt_new = clf.inverse_transform(Xt=Xt_new)
+
+ # check that the oos method is not working and returns the input data
+ assert_equal(transp_Xt_new, Xt_new)
def test_otda():