summaryrefslogtreecommitdiff
path: root/ot/da.py
diff options
context:
space:
mode:
authorSlasnista <stan.chambon@gmail.com>2017-07-28 14:52:36 +0200
committerSlasnista <stan.chambon@gmail.com>2017-07-28 14:52:36 +0200
commit84adadd69aef12826aa3970d135d499ef4d13f64 (patch)
treedbda4e899026ab6175c5f847f5bf59a980a15b2e /ot/da.py
parentca9c9d6d8ecef6a38e0fd6240538a8af35ad06f5 (diff)
small modifs according to NG proposals
Diffstat (limited to 'ot/da.py')
-rw-r--r--ot/da.py43
1 files changed, 31 insertions, 12 deletions
diff --git a/ot/da.py b/ot/da.py
index f534bf5..a422f7c 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -926,8 +926,8 @@ from sklearn.metrics import pairwise_distances
"""
- all methods have the same input parameters: Xs, Xt, ys, yt (what order ?)
-- ref: is the entropic reg parameter
-- eta: is the second reg parameter
+- reg_e: is the entropic reg parameter
+- reg_cl: is the second reg parameter
- gamma_: is the optimal coupling
- mapping barycentric for the moment
@@ -940,7 +940,7 @@ Questions:
class BaseTransport(BaseEstimator):
- def fit(self, Xs=None, ys=None, Xt=None, yt=None, method="sinkhorn"):
+ def fit(self, Xs=None, ys=None, Xt=None, yt=None, method=None):
"""fit: estimates the optimal coupling
Parameters:
@@ -964,13 +964,17 @@ class BaseTransport(BaseEstimator):
print("TODO: modify cost matrix accordingly")
pass
- # distribution estimation: should we change it ?
- mu_s = np.ones(Xs.shape[0]) / float(Xs.shape[0])
- mu_t = np.ones(Xt.shape[0]) / float(Xt.shape[0])
+ # distribution estimation
+ if self.distribution == "uniform":
+ mu_s = np.ones(Xs.shape[0]) / float(Xs.shape[0])
+ mu_t = np.ones(Xt.shape[0]) / float(Xt.shape[0])
+ else:
+ print("TODO: implement kernelized approach")
+ # coupling estimation
if method == "sinkhorn":
self.gamma_ = sinkhorn(
- a=mu_s, b=mu_t, M=Cost, reg=self.reg,
+ a=mu_s, b=mu_t, M=Cost, reg=self.reg_e,
numItermax=self.max_iter, stopThr=self.tol,
verbose=self.verbose, log=self.log)
else:
@@ -1058,7 +1062,7 @@ class SinkhornTransport(BaseTransport):
Parameters
----------
- - reg : parameter for entropic regularization
+ - reg_e : parameter for entropic regularization
- mode: unsupervised (default) or semi supervised: controls whether
labels are taken into accout to construct the optimal coupling
- max_iter : maximum number of iterations
@@ -1071,10 +1075,10 @@ class SinkhornTransport(BaseTransport):
- gamma_: optimal coupling estimated by the fit function
"""
- def __init__(self, reg=1., mode="unsupervised", max_iter=1000,
+ def __init__(self, reg_e=1., mode="unsupervised", max_iter=1000,
tol=10e-9, verbose=False, log=False, mapping="barycentric",
- metric="sqeuclidean"):
- self.reg = reg
+ metric="sqeuclidean", distribution="uniform"):
+ self.reg_e = reg_e
self.mode = mode
self.max_iter = max_iter
self.tol = tol
@@ -1082,11 +1086,26 @@ class SinkhornTransport(BaseTransport):
self.log = log
self.mapping = mapping
self.metric = metric
+ self.distribution = distribution
self.method = "sinkhorn"
def fit(self, Xs=None, ys=None, Xt=None, yt=None):
- """_fit
+ """fit
+
+ Parameters:
+ -----------
+ - Xs: source samples, (ns samples, d features) numpy-like array
+ - ys: source labels
+ - Xt: target samples (nt samples, d features) numpy-like array
+ - yt: target labels
+ - method: algorithm to use to compute optimal coupling
+ (default: sinkhorn)
+
+ Returns:
+ --------
+ - self
"""
+
return super(SinkhornTransport, self).fit(
Xs, ys, Xt, yt, method=self.method)