 README.md       |   3
 ot/bregman.py   | 157
 ot/da.py        | 190
 test/test_da.py |   2
 4 files changed, 171 insertions(+), 181 deletions(-)
diff --git a/README.md b/README.md
index c115776..f439405 100644
--- a/README.md
+++ b/README.md
@@ -29,6 +29,7 @@ It provides the following solvers:
* Non regularized free support Wasserstein barycenters [20].
* Unbalanced OT with KL relaxation distance and barycenter [10, 25].
* Screening Sinkhorn Algorithm for OT [26].
+* JCPOT algorithm for multi-source target shift [27].
Some demonstrations (both in Python and Jupyter Notebook format) are available in the examples folder.
@@ -257,3 +258,5 @@ You can also post bug reports and feature requests in Github issues. Make sure t
[25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. (2015). [Learning with a Wasserstein Loss](http://cbcl.mit.edu/wasserstein/) Advances in Neural Information Processing Systems (NIPS).
[26] Alaya M. Z., Bérar M., Gasso G., Rakotomamonjy A. (2019). [Screening Sinkhorn Algorithm for Regularized Optimal Transport](https://papers.nips.cc/paper/9386-screening-sinkhorn-algorithm-for-regularized-optimal-transport), Advances in Neural Information Processing Systems 33 (NeurIPS).
+
+[27] Redko I., Courty N., Flamary R., Tuia D. (2019). [Optimal Transport for Multi-source Domain Adaptation under Target Shift](http://proceedings.mlr.press/v89/redko19a.html), Proceedings of the Twenty-Second International Conference on Artificial Intelligence and Statistics (AISTATS) 22, 2019.
\ No newline at end of file
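For context, the solver announced in the new README entry above is the `jcpot_barycenter` function added to `ot/bregman.py` further down in this diff. A minimal usage sketch against that signature (the toy data and variable names below are illustrative only, not taken from the POT examples):

```python
import numpy as np
import ot

# Two toy source domains with different class proportions (illustrative data).
rng = np.random.RandomState(42)
Xs = [rng.randn(50, 2), rng.randn(40, 2) + 1]                          # source feature arrays
Ys = [np.array([0] * 30 + [1] * 20), np.array([0] * 10 + [1] * 30)]    # matching label arrays
Xt = rng.randn(60, 2)                                                  # unlabeled target samples

# Signature as introduced in this diff: one coupling per source domain,
# plus the estimated class proportions in the target domain.
couplings, prop = ot.bregman.jcpot_barycenter(Xs, Ys, Xt, reg=1.0)
print(prop)  # length-C probability vector summing to 1
```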
diff --git a/ot/bregman.py b/ot/bregman.py
index d5e3563..d17aaf0 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -10,6 +10,7 @@ Bregman projections for regularized OT
# Hicham Janati <hicham.janati@inria.fr>
# Mokhtar Z. Alaya <mokhtarzahdi.alaya@gmail.com>
# Alexander Tong <alexander.tong@yale.edu>
+# Ievgen Redko <ievgen.redko@univ-st-etienne.fr>
#
# License: MIT License
@@ -18,7 +19,6 @@ import warnings
from .utils import unif, dist
from scipy.optimize import fmin_l_bfgs_b
-
def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
stopThr=1e-9, verbose=False, log=False, **kwargs):
r"""
@@ -1501,6 +1501,161 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
else:
return np.sum(K0, axis=1)
+def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
+ stopThr=1e-6, verbose=False, log=False, **kwargs):
+ r'''Joint OT and proportion estimation for multi-source target shift as proposed in [27]
+
+ The function solves the following optimization problem:
+
+ .. math::
+
+ \mathbf{h} = arg\min_{\mathbf{h}}\quad \sum_{k=1}^{K} \lambda_k
+ W_{reg}((\mathbf{D}_2^{(k)} \mathbf{h})^T, \mathbf{a})
+
+ s.t. \ \forall k, \mathbf{D}_1^{(k)} \gamma_k \mathbf{1}_n= \mathbf{h}
+
+ where :
+
+ - :math:`\lambda_k` is the weight of the k-th source domain
+ - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see ot.bregman.sinkhorn)
+ - :math:`\mathbf{D}_2^{(k)}` is a matrix of weights related to the k-th source domain defined as in [p. 5, 27]; its expected shape is `(n_k, C)` where `n_k` is the number of elements in the k-th source domain and `C` is the number of classes
+ - :math:`\mathbf{h}` is a vector of estimated proportions in the target domain of size C
+ - :math:`\mathbf{a}` is a uniform vector of weights in the target domain of size `n`
+ - :math:`\mathbf{D}_1^{(k)}` is a matrix of class assignments defined as in [p. 5, 27], its expected shape is `(n_k, C)`
+
+ The problem consists in solving a Wasserstein barycenter problem to estimate the proportions :math:`\mathbf{h}` in the target domain.
+
+ The algorithm used for solving the problem is the Iterative Bregman projections algorithm
+ with two sets of marginal constraints related to the unknown vector :math:`\mathbf{h}` and the uniform target distribution.
+
+ Parameters
+ ----------
+ Xs : list of K np.ndarray(nsk,d)
+ features of all source domains' samples
+ Ys : list of K np.ndarray(nsk,)
+ labels of all source domains' samples
+ Xt : np.ndarray (nt,d)
+ samples in the target domain
+ reg : float
+ Regularization term > 0
+ metric : string, optional (default="sqeuclidean")
+ The ground metric for the Wasserstein problem
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshold on relative change in the barycenter (>0)
+ log : bool, optional
+ record log if True
+ verbose : bool, optional (default=False)
+ Controls the verbosity of the optimization algorithm
+
+ Returns
+ -------
+ gamma : List of K (nsk x nt) ndarrays
+ Optimal transportation matrices for the given parameters for each pair of source and target domains
+ h : (C,) ndarray
+ proportion estimation in the target domain
+ log : dict
+ log dictionary returned only if log==True in parameters
+
+
+ References
+ ----------
+
+ .. [27] Ievgen Redko, Nicolas Courty, Rémi Flamary, Devis Tuia
+ "Optimal transport for multi-source domain adaptation under target shift",
+ International Conference on Artificial Intelligence and Statistics (AISTATS), 2019.
+
+ '''
+ nbclasses = len(np.unique(Ys[0]))
+ nbdomains = len(Xs)
+
+ # For each source domain, build cost matrices M, Gibbs kernels K and corresponding matrices D_1 and D_2
+ all_domains = []
+
+ # log dictionary
+ if log:
+ log = {'niter': 0, 'err': [], 'all_domains': []}
+
+ for d in range(nbdomains):
+ dom = {}
+ nsk = Xs[d].shape[0] # get number of elements for this domain
+ dom['nbelem'] = nsk
+ classes = np.unique(Ys[d]) # get number of classes for this domain
+
+ # format classes to start from 0 for convenience
+ if np.min(classes) != 0:
+ Ys[d] = Ys[d] - np.min(classes)
+ classes = np.unique(Ys[d])
+
+ # build the corresponding D_1 and D_2 matrices
+ D1 = np.zeros((nbclasses, nsk))
+ D2 = np.zeros((nbclasses, nsk))
+
+ for c in classes:
+ nbelemperclass = np.sum(Ys[d] == c)
+ if nbelemperclass != 0:
+ D1[int(c), Ys[d] == c] = 1.
+ D2[int(c), Ys[d] == c] = 1. / (nbelemperclass)
+ dom['D1'] = D1
+ dom['D2'] = D2
+
+ # build the cost matrix and the Gibbs kernel
+ M = dist(Xs[d], Xt, metric=metric)
+ M = M / np.median(M)
+
+ K = np.empty(M.shape, dtype=M.dtype)
+ np.divide(M, -reg, out=K)
+ np.exp(K, out=K)
+ dom['K'] = K
+
+ all_domains.append(dom)
+
+ # uniform target distribution
+ a = unif(np.shape(Xt)[0])
+
+ cpt = 0 # iterations count
+ err = 1
+ old_bary = np.ones((nbclasses))
+
+ while (err > stopThr and cpt < numItermax):
+
+ bary = np.zeros((nbclasses))
+
+ # update coupling matrices for marginal constraints w.r.t. uniform target distribution
+ for d in range(nbdomains):
+ all_domains[d]['K'] = projC(all_domains[d]['K'], a)
+ other = np.sum(all_domains[d]['K'], axis=1)
+ bary = bary + np.log(np.dot(all_domains[d]['D1'], other)) / nbdomains
+
+ bary = np.exp(bary)
+
+ # update coupling matrices for marginal constraints w.r.t. unknown proportions based on [Prop 4., 27]
+ for d in range(nbdomains):
+ new = np.dot(all_domains[d]['D2'].T, bary)
+ all_domains[d]['K'] = projR(all_domains[d]['K'], new)
+
+ err = np.linalg.norm(bary - old_bary)
+ cpt = cpt + 1
+ old_bary = bary
+
+ if log:
+ log['err'].append(err)
+
+ if verbose:
+ if cpt % 200 == 0:
+ print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, err))
+
+ bary = bary / np.sum(bary)
+ couplings = [all_domains[d]['K'] for d in range(nbdomains)]
+
+ if log:
+ log['niter'] = cpt
+ log['all_domains'] = all_domains
+ return couplings, bary, log
+ else:
+ return couplings, bary
def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
numIterMax=10000, stopThr=1e-9, verbose=False,
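The Bregman projection helpers `projC` and `projR` used in the loop above are imported from elsewhere in `ot.bregman` and do not appear in this diff. A simplified sketch of what such KL projections do, assuming the usual row/column rescaling form (not POT's exact implementation): `projC(K, a)` rescales the columns of each coupling so that the column sums match the uniform target marginal `a`, while `projR(K, new)` rescales the rows so that the source marginal matches the proportions `D2.T @ bary` implied by the current barycenter estimate.

```python
import numpy as np

def proj_cols(gamma, q):
    # Rescale columns so that gamma.sum(axis=0) equals q
    # (KL projection onto the column-marginal constraint set).
    return gamma * (q / np.maximum(gamma.sum(axis=0), 1e-10))[None, :]

def proj_rows(gamma, p):
    # Rescale rows so that gamma.sum(axis=1) equals p
    # (KL projection onto the row-marginal constraint set).
    return gamma * (p / np.maximum(gamma.sum(axis=1), 1e-10))[:, None]
```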
diff --git a/ot/da.py b/ot/da.py
index a3da8c1..a9c3cea 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -7,20 +7,20 @@ Domain adaptation with optimal transport
# Nicolas Courty <ncourty@irisa.fr>
# Michael Perrot <michael.perrot@univ-st-etienne.fr>
# Nathalie Gayraud <nat.gayraud@gmail.com>
+# Ievgen Redko <ievgen.redko@univ-st-etienne.fr>
#
# License: MIT License
import numpy as np
import scipy.linalg as linalg
-from .bregman import sinkhorn, projR, projC
+from .bregman import sinkhorn
from .lp import emd
from .utils import unif, dist, kernel, cost_normalization
from .utils import check_params, BaseEstimator
from .unbalanced import sinkhorn_unbalanced
from .optim import cg
from .optim import gcg
-from functools import reduce
def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
@@ -128,7 +128,7 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
W = np.ones(M.shape)
for (i, c) in enumerate(classes):
majs = np.sum(transp[indices_labels[i]], axis=0)
- majs = p * ((majs + epsilon)**(p - 1))
+ majs = p * ((majs + epsilon) ** (p - 1))
W[indices_labels[i]] = majs
return transp
@@ -360,8 +360,8 @@ def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False,
def loss(L, G):
"""Compute full loss"""
- return np.sum((xs1.dot(L) - ns * G.dot(xt))**2) + mu * \
- np.sum(G * M) + eta * np.sum(sel(L - I0)**2)
+ return np.sum((xs1.dot(L) - ns * G.dot(xt)) ** 2) + mu * \
+ np.sum(G * M) + eta * np.sum(sel(L - I0) ** 2)
def solve_L(G):
""" solve L problem with fixed G (least square)"""
@@ -373,10 +373,11 @@ def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False,
xsi = xs1.dot(L)
def f(G):
- return np.sum((xsi - ns * G.dot(xt))**2)
+ return np.sum((xsi - ns * G.dot(xt)) ** 2)
def df(G):
return -2 * ns * (xsi - ns * G.dot(xt)).dot(xt.T)
+
G = cg(a, b, M, 1.0 / mu, f, df, G0=G0,
numItermax=numInnerItermax, stopThr=stopInnerThr)
return G
@@ -563,8 +564,8 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian',
def loss(L, G):
"""Compute full loss"""
- return np.sum((K1.dot(L) - ns * G.dot(xt))**2) + mu * \
- np.sum(G * M) + eta * np.trace(L.T.dot(Kreg).dot(L))
+ return np.sum((K1.dot(L) - ns * G.dot(xt)) ** 2) + mu * \
+ np.sum(G * M) + eta * np.trace(L.T.dot(Kreg).dot(L))
def solve_L_nobias(G):
""" solve L problem with fixed G (least square)"""
@@ -581,10 +582,11 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian',
xsi = K1.dot(L)
def f(G):
- return np.sum((xsi - ns * G.dot(xt))**2)
+ return np.sum((xsi - ns * G.dot(xt)) ** 2)
def df(G):
return -2 * ns * (xsi - ns * G.dot(xt)).dot(xt.T)
+
G = cg(a, b, M, 1.0 / mu, f, df, G0=G0,
numItermax=numInnerItermax, stopThr=stopInnerThr)
return G
@@ -746,163 +748,6 @@ def OT_mapping_linear(xs, xt, reg=1e-6, ws=None,
return A, b
-def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
- stopThr=1e-6, verbose=False, log=False, **kwargs):
- r'''Joint OT and proportion estimation for multi-source target shift as proposed in [27]
-
- The function solves the following optimization problem:
-
- .. math::
-
- \mathbf{h} = arg\min_{\mathbf{h}}\quad \sum_{k=1}^{K} \lambda_k
- W_{reg}((\mathbf{D}_2^{(k)} \mathbf{h})^T, \mathbf{a})
-
- s.t. \ \forall k, \mathbf{D}_1^{(k)} \gamma_k \mathbf{1}_n= \mathbf{h}
-
- where :
-
- - :math:`\lambda_k` is the weight of k-th source domain
- - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see ot.bregman.sinkhorn)
- - :math:`\mathbf{D}_2^{(k)}` is a matrix of weights related to k-th source domain defined as in [p. 5, 27], its expected shape is `(n_k, C)` where `n_k` is the number of elements in the k-th source domain and `C` is the number of classes
- - :math:`\mathbf{h}` is a vector of estimated proportions in the target domain of size C
- - :math:`\mathbf{a}` is a uniform vector of weights in the target domain of size `n`
- - :math:`\mathbf{D}_1^{(k)}` is a matrix of class assignments defined as in [p. 5, 27], its expected shape is `(n_k, C)`
-
- The problem consist in solving a Wasserstein barycenter problem to estimate the proportions :math:`\mathbf{h}` in the target domain.
-
- The algorithm used for solving the problem is the Iterative Bregman projections algorithm
- with two sets of marginal constraints related to the unknown vector :math:`\mathbf{h}` and uniform tarhet distribution.
-
- Parameters
- ----------
- Xs : list of K np.ndarray(nsk,d)
- features of all source domains' samples
- Ys : list of K np.ndarray(nsk,)
- labels of all source domains' samples
- Xt : np.ndarray (nt,d)
- samples in the target domain
- reg : float
- Regularization term > 0
- metric : string, optional (default="sqeuclidean")
- The ground metric for the Wasserstein problem
- numItermax : int, optional
- Max number of iterations
- stopThr : float, optional
- Stop threshold on relative change in the barycenter (>0)
- log : bool, optional
- record log if True
- verbose : bool, optional (default=False)
- Controls the verbosity of the optimization algorithm
-
- Returns
- -------
- gamma : List of K (nsk x nt) ndarrays
- Optimal transportation matrices for the given parameters for each pair of source and target domains
- h : (C,) ndarray
- proportion estimation in the target domain
- log : dict
- log dictionary return only if log==True in parameters
-
-
- References
- ----------
-
- .. [27] Ievgen Redko, Nicolas Courty, Rémi Flamary, Devis Tuia
- "Optimal transport for multi-source domain adaptation under target shift",
- International Conference on Artificial Intelligence and Statistics (AISTATS), 2019.
-
- '''
- nbclasses = len(np.unique(Ys[0]))
- nbdomains = len(Xs)
-
- # For each source domain, build cost matrices M, Gibbs kernels K and corresponding matrices D_1 and D_2
- all_domains = []
-
- # log dictionary
- if log:
- log = {'niter': 0, 'err': [], 'all_domains': []}
-
- for d in range(nbdomains):
- dom = {}
- nsk = Xs[d].shape[0] # get number of elements for this domain
- dom['nbelem'] = nsk
- classes = np.unique(Ys[d]) # get number of classes for this domain
-
- # format classes to start from 0 for convenience
- if np.min(classes) != 0:
- Ys[d] = Ys[d] - np.min(classes)
- classes = np.unique(Ys[d])
-
- # build the corresponding D_1 and D_2 matrices
- D1 = np.zeros((nbclasses, nsk))
- D2 = np.zeros((nbclasses, nsk))
-
- for c in classes:
- nbelemperclass = np.sum(Ys[d] == c)
- if nbelemperclass != 0:
- D1[int(c), Ys[d] == c] = 1.
- D2[int(c), Ys[d] == c] = 1. / (nbelemperclass)
- dom['D1'] = D1
- dom['D2'] = D2
-
- # build the cost matrix and the Gibbs kernel
- M = dist(Xs[d], Xt, metric=metric)
- M = M / np.median(M)
-
- K = np.empty(M.shape, dtype=M.dtype)
- np.divide(M, -reg, out=K)
- np.exp(K, out=K)
- dom['K'] = K
-
- all_domains.append(dom)
-
- # uniform target distribution
- a = unif(np.shape(Xt)[0])
-
- cpt = 0 # iterations count
- err = 1
- old_bary = np.ones((nbclasses))
-
- while (err > stopThr and cpt < numItermax):
-
- bary = np.zeros((nbclasses))
-
- # update coupling matrices for marginal constraints w.r.t. uniform target distribution
- for d in range(nbdomains):
- all_domains[d]['K'] = projC(all_domains[d]['K'], a)
- other = np.sum(all_domains[d]['K'], axis=1)
- bary = bary + np.log(np.dot(all_domains[d]['D1'], other)) / nbdomains
-
- bary = np.exp(bary)
-
- # update coupling matrices for marginal constraints w.r.t. unknown proportions based on [Prop 4., 27]
- for d in range(nbdomains):
- new = np.dot(all_domains[d]['D2'].T, bary)
- all_domains[d]['K'] = projR(all_domains[d]['K'], new)
-
- err = np.linalg.norm(bary - old_bary)
- cpt = cpt + 1
- old_bary = bary
-
- if log:
- log['err'].append(err)
-
- if verbose:
- if cpt % 200 == 0:
- print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
- print('{:5d}|{:8e}|'.format(cpt, err))
-
- bary = bary / np.sum(bary)
- couplings = [all_domains[d]['K'] for d in range(nbdomains)]
-
- if log:
- log['niter'] = cpt
- log['all_domains'] = all_domains
- return couplings, bary, log
- else:
- return couplings, bary
-
-
def distribution_estimation_uniform(X):
"""estimates a uniform distribution from an array of samples X
@@ -921,7 +766,6 @@ def distribution_estimation_uniform(X):
class BaseTransport(BaseEstimator):
-
"""Base class for OTDA objects
Notes
@@ -1079,7 +923,6 @@ class BaseTransport(BaseEstimator):
transp_Xs = []
for bi in batch_ind:
-
# get the nearest neighbor in the source domain
D0 = dist(Xs[bi], self.xs_)
idx = np.argmin(D0, axis=1)
@@ -1148,7 +991,6 @@ class BaseTransport(BaseEstimator):
transp_Xt = []
for bi in batch_ind:
-
D0 = dist(Xt[bi], self.xt_)
idx = np.argmin(D0, axis=1)
@@ -1294,7 +1136,6 @@ class LinearTransport(BaseTransport):
# check the necessary inputs parameters are here
if check_params(Xs=Xs):
-
transp_Xs = Xs.dot(self.A_) + self.B_
return transp_Xs
@@ -1328,14 +1169,12 @@ class LinearTransport(BaseTransport):
# check the necessary inputs parameters are here
if check_params(Xt=Xt):
-
transp_Xt = Xt.dot(self.A1_) + self.B1_
return transp_Xt
class SinkhornTransport(BaseTransport):
-
"""Domain Adapatation OT method based on Sinkhorn Algorithm
Parameters
@@ -1445,7 +1284,6 @@ class SinkhornTransport(BaseTransport):
class EMDTransport(BaseTransport):
-
"""Domain Adapatation OT method based on Earth Mover's Distance
Parameters
@@ -1537,7 +1375,6 @@ class EMDTransport(BaseTransport):
class SinkhornLpl1Transport(BaseTransport):
-
"""Domain Adapatation OT method based on sinkhorn algorithm +
LpL1 class regularization.
@@ -1639,7 +1476,6 @@ class SinkhornLpl1Transport(BaseTransport):
# check the necessary inputs parameters are here
if check_params(Xs=Xs, Xt=Xt, ys=ys):
-
super(SinkhornLpl1Transport, self).fit(Xs, ys, Xt, yt)
returned_ = sinkhorn_lpl1_mm(
@@ -1658,7 +1494,6 @@ class SinkhornLpl1Transport(BaseTransport):
class SinkhornL1l2Transport(BaseTransport):
-
"""Domain Adapatation OT method based on sinkhorn algorithm +
l1l2 class regularization.
@@ -1782,7 +1617,6 @@ class SinkhornL1l2Transport(BaseTransport):
class MappingTransport(BaseEstimator):
-
"""MappingTransport: DA methods that aims at jointly estimating a optimal
transport coupling and the associated mapping
@@ -1956,7 +1790,6 @@ class MappingTransport(BaseEstimator):
class UnbalancedSinkhornTransport(BaseTransport):
-
"""Domain Adapatation unbalanced OT method based on sinkhorn algorithm
Parameters
@@ -2075,7 +1908,6 @@ class UnbalancedSinkhornTransport(BaseTransport):
class JCPOTTransport(BaseTransport):
-
"""Domain Adapatation OT method for multi-source target shift based on Wasserstein barycenter algorithm.
Parameters
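The `JCPOTTransport` class whose docstring opens in the hunk above wraps `jcpot_barycenter` behind the same fit/transform interface as the other transport classes in `ot.da`. A hedged usage sketch follows; the constructor arguments `reg_e`, `metric` and `max_iter` mirror the conventions of the neighbouring classes and are assumptions here, since this diff does not show the constructor:

```python
import numpy as np
import ot

rng = np.random.RandomState(0)
Xs = [rng.randn(50, 2), rng.randn(40, 2) + 1]                          # two source domains
ys = [np.array([0] * 30 + [1] * 20), np.array([0] * 10 + [1] * 30)]    # their labels
Xt = rng.randn(60, 2)                                                  # unlabeled target samples

# Multi-source fit: one coupling per (source, target) pair is estimated internally.
otda = ot.da.JCPOTTransport(reg_e=1.0, metric='sqeuclidean', max_iter=100)
otda.fit(Xs=Xs, ys=ys, Xt=Xt)

# transform() maps each source domain onto the target and returns a list,
# one transported array per source domain.
transp_Xs = otda.transform(Xs=Xs)
```

As the test hunk below also checks, the transformed output is a list of arrays matching the shapes of the input source domains rather than a single stacked array.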
diff --git a/test/test_da.py b/test/test_da.py
index a8c258a..958df7b 100644
--- a/test/test_da.py
+++ b/test/test_da.py
@@ -8,6 +8,7 @@ import numpy as np
from numpy.testing import assert_allclose, assert_equal
import ot
+from ot.bregman import jcpot_barycenter
from ot.datasets import make_data_classif
from ot.utils import unif
@@ -603,7 +604,6 @@ def test_jcpot_transport_class():
# test transform
transp_Xs = otda.transform(Xs=Xs)
[assert_equal(x.shape, y.shape) for x, y in zip(transp_Xs, Xs)]
- #assert_equal(transp_Xs.shape, Xs.shape)
Xs_new, _ = make_data_classif('3gauss', ns1 + 1)
transp_Xs_new = otda.transform(Xs_new)
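Since the test module now imports `jcpot_barycenter` directly (see the added import above), a direct unit test of the solver would look roughly like the sketch below. This is a hypothetical example built on the same `make_data_classif` helper used by the existing tests, not the actual test added in this change:

```python
import numpy as np
from numpy.testing import assert_allclose, assert_equal
from ot.bregman import jcpot_barycenter
from ot.datasets import make_data_classif


def test_jcpot_barycenter_sketch():
    ns1, ns2, nt = 50, 50, 50
    Xs1, ys1 = make_data_classif('3gauss', ns1)
    Xs2, ys2 = make_data_classif('3gauss', ns2)
    Xt, _ = make_data_classif('3gauss2', nt)

    couplings, prop = jcpot_barycenter([Xs1, Xs2], [ys1, ys2], Xt, reg=1.0)

    # estimated target proportions form a probability vector
    assert_allclose(prop.sum(), 1.0, rtol=1e-6)
    # one coupling per source domain, each of shape (nsk, nt)
    assert_equal(len(couplings), 2)
    assert_equal(couplings[0].shape, (ns1, nt))
    assert_equal(couplings[1].shape, (ns2, nt))
```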