summary | refs | log | tree | commit | diff
path: root/ot/da.py
diff options
context:
space:
mode:
author: Slasnista <stan.chambon@gmail.com> 2017-08-01 10:42:09 +0200
committer: Slasnista <stan.chambon@gmail.com> 2017-08-01 10:42:09 +0200
commit: 122b5bf2c0c8b6ff7b46adf19c7dd72e62c85b1f (patch)
tree: 8acd53375c070e1fdefa8e458aa0acf0683a8c3d /ot/da.py
parent: bd7c7d2534980d3105d060dd24a444433422134d (diff)
update SinkhornTransport class + added test for class
Diffstat (limited to 'ot/da.py')
-rw-r--r-- ot/da.py 56
1 files changed, 21 insertions, 35 deletions
diff --git a/ot/da.py b/ot/da.py
index d30c821..6b98a17 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -15,6 +15,7 @@ from .lp import emd
from .utils import unif, dist, kernel
from .optim import cg
from .optim import gcg
+import warnings
def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
@@ -921,15 +922,8 @@ class OTDA_mapping_kernel(OTDA_mapping_linear):
# proposal
##############################################################################
-# from sklearn.base import BaseEstimator
-# from sklearn.metrics import pairwise_distances
-
-##############################################################################
-# adapted from scikit-learn
-
-import warnings
-# from .externals.six import string_types, iteritems
+# adapted from sklearn
class BaseEstimator(object):
"""Base class for all estimators in scikit-learn
@@ -1067,7 +1061,7 @@ def distribution_estimation_uniform(X):
The uniform distribution estimated from X
"""
- return np.ones(X.shape[0]) / float(X.shape[0])
+ return unif(X.shape[0])
class BaseTransport(BaseEstimator):
@@ -1092,29 +1086,20 @@ class BaseTransport(BaseEstimator):
"""
# pairwise distance
- Cost = dist(Xs, Xt, metric=self.metric)
+ self.Cost = dist(Xs, Xt, metric=self.metric)
if self.mode == "semisupervised":
print("TODO: modify cost matrix accordingly")
pass
# distribution estimation
- mu_s = self.distribution_estimation(Xs)
- mu_t = self.distribution_estimation(Xt)
+ self.mu_s = self.distribution_estimation(Xs)
+ self.mu_t = self.distribution_estimation(Xt)
# store arrays of samples
self.Xs = Xs
self.Xt = Xt
- # coupling estimation
- if self.method == "sinkhorn":
- self.gamma_ = sinkhorn(
- a=mu_s, b=mu_t, M=Cost, reg=self.reg_e,
- numItermax=self.max_iter, stopThr=self.tol,
- verbose=self.verbose, log=self.log)
- else:
- print("TODO: implement the other methods")
-
return self
def fit_transform(self, Xs=None, ys=None, Xt=None, yt=None):
@@ -1157,8 +1142,7 @@ class BaseTransport(BaseEstimator):
The transport source samples.
"""
- # TODO: check whether Xs is new or not
- if self.Xs == Xs:
+ if np.array_equal(self.Xs, Xs):
# perform standard barycentric mapping
transp = self.gamma_ / np.sum(self.gamma_, 1)[:, None]
@@ -1169,7 +1153,9 @@ class BaseTransport(BaseEstimator):
transp_Xs = np.dot(transp, self.Xt)
else:
# perform out of sample mapping
- print("out of sample mapping not yet implemented")
+ print("Warning: out of sample mapping not yet implemented")
+ print("input data will be returned")
+ transp_Xs = Xs
return transp_Xs
@@ -1191,8 +1177,7 @@ class BaseTransport(BaseEstimator):
The transported target samples.
"""
- # TODO: check whether Xt is new or not
- if self.Xt == Xt:
+ if np.array_equal(self.Xt, Xt):
# perform standard barycentric mapping
transp_ = self.gamma_.T / np.sum(self.gamma_, 0)[:, None]
@@ -1203,7 +1188,9 @@ class BaseTransport(BaseEstimator):
transp_Xt = np.dot(transp_, self.Xs)
else:
# perform out of sample mapping
- print("out of sample mapping not yet implemented")
+ print("Warning: out of sample mapping not yet implemented")
+ print("input data will be returned")
+ transp_Xt = Xt
return transp_Xt
@@ -1254,7 +1241,7 @@ class SinkhornTransport(BaseTransport):
"""
def __init__(self, reg_e=1., mode="unsupervised", max_iter=1000,
- tol=10e-9, verbose=False, log=False, mapping="barycentric",
+ tol=10e-9, verbose=False, log=False,
metric="sqeuclidean",
distribution_estimation=distribution_estimation_uniform,
out_of_sample_map='ferradans'):
@@ -1265,7 +1252,6 @@ class SinkhornTransport(BaseTransport):
self.tol = tol
self.verbose = verbose
self.log = log
- self.mapping = mapping
self.metric = metric
self.distribution_estimation = distribution_estimation
self.method = "sinkhorn"
@@ -1290,10 +1276,10 @@ class SinkhornTransport(BaseTransport):
Returns self.
"""
- return super(SinkhornTransport, self).fit(Xs, ys, Xt, yt)
-
+ self = super(SinkhornTransport, self).fit(Xs, ys, Xt, yt)
-if __name__ == "__main__":
- print("Small test")
-
- st = SinkhornTransport()
+ # coupling estimation
+ self.gamma_ = sinkhorn(
+ a=self.mu_s, b=self.mu_t, M=self.Cost, reg=self.reg_e,
+ numItermax=self.max_iter, stopThr=self.tol,
+ verbose=self.verbose, log=self.log)