author    Rémi Flamary <remi.flamary@gmail.com>  2020-04-15 13:16:30 +0200
committer GitHub <noreply@github.com>            2020-04-15 13:16:30 +0200
commit    adc5570550676b63b9aabb2205a67c5b7c9187f3 (patch)
tree      0082b2ea3843bb50738eb4689fb1eb9c74b85034 /ot/bregman.py
parent    4cd4e09f89fe6f95a07d632365612b797ab760da (diff)
parent    7889484b79a425ebf3632444547a6092e814bf20 (diff)
Merge pull request #137 from ievred/jcpot
[MRG] Jcpot : Multi source DA with target shift
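
This merge adds `jcpot_barycenter` to ot.bregman. A minimal usage sketch, based
only on the signature introduced in this diff (the toy data below is
illustrative and not part of the PR):

    import numpy as np
    from ot.bregman import jcpot_barycenter

    rng = np.random.RandomState(42)
    # two source domains, two classes, features in R^2
    Xs = [rng.randn(50, 2), rng.randn(40, 2)]
    Ys = [rng.randint(0, 2, 50), rng.randint(0, 2, 40)]
    Xt = rng.randn(60, 2)  # unlabeled target samples

    # estimated class proportions h in the target domain, shape (C,)
    h = jcpot_barycenter(Xs, Ys, Xt, reg=0.5)
    print(h, h.sum())  # h is normalized to sum to 1
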
Diffstat (limited to 'ot/bregman.py')
-rw-r--r--  ot/bregman.py  | 224
1 file changed, 198 insertions(+), 26 deletions(-)
diff --git a/ot/bregman.py b/ot/bregman.py
index d5e3563..543dbaa 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -10,6 +10,7 @@ Bregman projections for regularized OT
# Hicham Janati <hicham.janati@inria.fr>
# Mokhtar Z. Alaya <mokhtarzahdi.alaya@gmail.com>
# Alexander Tong <alexander.tong@yale.edu>
+# Ievgen Redko <ievgen.redko@univ-st-etienne.fr>
#
# License: MIT License
@@ -539,12 +540,12 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False,
old_v = v[i_2]
v[i_2] = b[i_2] / (K[:, i_2].T.dot(u))
G[:, i_2] = u * K[:, i_2] * v[i_2]
- #aviol = (G@one_m - a)
- #aviol_2 = (G.T@one_n - b)
+ # aviol = (G@one_m - a)
+ # aviol_2 = (G.T@one_n - b)
viol += (-old_v + v[i_2]) * K[:, i_2] * u
viol_2[i_2] = v[i_2] * K[:, i_2].dot(u) - b[i_2]
- #print('b',np.max(abs(aviol -viol)),np.max(abs(aviol_2 - viol_2)))
+ # print('b',np.max(abs(aviol -viol)),np.max(abs(aviol_2 - viol_2)))
if stopThr_val <= stopThr:
break
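
Context for the reflowed comments above: greenkhorn keeps the marginal
violations up to date incrementally rather than recomputing `G @ one_m - a`
after every step (the full recomputation is what the commented-out lines did).
After rescaling column `i_2` from `old_v` to `v[i_2]`, the row-marginal
violation changes only through that column. A standalone numpy check of this
rank-one update (not part of the patch):

    import numpy as np

    rng = np.random.RandomState(0)
    n, m = 5, 4
    K = rng.rand(n, m)
    u, v = rng.rand(n), rng.rand(m)
    a = np.full(n, 1. / n)

    G = u[:, None] * K * v[None, :]
    viol = G.sum(axis=1) - a                   # full recompute of row violation

    i_2 = 2
    old_v = v[i_2]
    v[i_2] *= 1.7                              # stand-in for a column update
    viol += (-old_v + v[i_2]) * K[:, i_2] * u  # O(n) incremental update, as in the code

    G = u[:, None] * K * v[None, :]
    assert np.allclose(viol, G.sum(axis=1) - a)  # matches the full recompute
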
@@ -940,7 +941,7 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4,
# the 10th iterations
transp = G
err = np.linalg.norm(
- (np.sum(transp, axis=0) - b))**2 + np.linalg.norm((np.sum(transp, axis=1) - a))**2
+ (np.sum(transp, axis=0) - b)) ** 2 + np.linalg.norm((np.sum(transp, axis=1) - a)) ** 2
if log:
log['err'].append(err)
@@ -966,7 +967,7 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4,
def geometricBar(weights, alldistribT):
"""return the weighted geometric mean of distributions"""
- assert(len(weights) == alldistribT.shape[1])
+ assert (len(weights) == alldistribT.shape[1])
return np.exp(np.dot(np.log(alldistribT), weights.T))
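
The assert reflow above touches `geometricBar`, which computes the barycenter
update exp(sum_k w_k log x_k), i.e. the column-wise weighted geometric mean,
in log-space for numerical stability. A quick standalone check that the
log-space form equals the direct product (illustrative, not from the patch):

    import numpy as np

    weights = np.array([0.3, 0.7])
    alldistribT = np.array([[0.2, 0.4],
                            [0.8, 0.6]])       # one distribution per column

    logspace = np.exp(np.dot(np.log(alldistribT), weights.T))
    direct = np.prod(alldistribT ** weights, axis=1)  # prod_k x_k ** w_k
    assert np.allclose(logspace, direct)
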
@@ -1108,7 +1109,7 @@ def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000,
if weights is None:
weights = np.ones(A.shape[1]) / A.shape[1]
else:
- assert(len(weights) == A.shape[1])
+ assert (len(weights) == A.shape[1])
if log:
log = {'err': []}
@@ -1206,7 +1207,7 @@ def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000,
if weights is None:
weights = np.ones(n_hists) / n_hists
else:
- assert(len(weights) == A.shape[1])
+ assert (len(weights) == A.shape[1])
if log:
log = {'err': []}
@@ -1334,7 +1335,7 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000,
if weights is None:
weights = np.ones(A.shape[0]) / A.shape[0]
else:
- assert(len(weights) == A.shape[0])
+ assert (len(weights) == A.shape[0])
if log:
log = {'err': []}
@@ -1350,11 +1351,11 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000,
# this is equivalent to blurring on horizontal then vertical directions
t = np.linspace(0, 1, A.shape[1])
[Y, X] = np.meshgrid(t, t)
- xi1 = np.exp(-(X - Y)**2 / reg)
+ xi1 = np.exp(-(X - Y) ** 2 / reg)
t = np.linspace(0, 1, A.shape[2])
[Y, X] = np.meshgrid(t, t)
- xi2 = np.exp(-(X - Y)**2 / reg)
+ xi2 = np.exp(-(X - Y) ** 2 / reg)
def K(x):
return np.dot(np.dot(xi1, x), xi2)
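
The `xi1`/`xi2` lines reformatted above build the Gibbs kernel of
`convolutional_barycenter2d` in separable form: the squared Euclidean distance
on a grid splits across coordinates, so applying the full
(n1*n2) x (n1*n2) kernel reduces to two small matrix products, one blurring
along each axis. A standalone sketch of the same trick:

    import numpy as np

    n1, n2, reg = 32, 32, 1e-2
    t = np.linspace(0, 1, n1)
    Y, X = np.meshgrid(t, t)
    xi1 = np.exp(-(X - Y) ** 2 / reg)   # 1-D Gaussian kernel along axis 0
    t = np.linspace(0, 1, n2)
    Y, X = np.meshgrid(t, t)
    xi2 = np.exp(-(X - Y) ** 2 / reg)   # 1-D Gaussian kernel along axis 1

    img = np.random.rand(n1, n2)
    blurred = xi1.dot(img).dot(xi2)     # K(img): two O(n^3) products instead
                                        # of one O(n^4) dense kernel apply
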
@@ -1502,6 +1503,164 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
return np.sum(K0, axis=1)
+def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
+ stopThr=1e-6, verbose=False, log=False, **kwargs):
+ r'''Joint OT and proportion estimation for multi-source target shift as proposed in [27]
+
+ The function solves the following optimization problem:
+
+ .. math::
+
+ \mathbf{h} = arg\min_{\mathbf{h}}\quad \sum_{k=1}^{K} \lambda_k
+ W_{reg}((\mathbf{D}_2^{(k)} \mathbf{h})^T, \mathbf{a})
+
+ s.t. \ \forall k, \mathbf{D}_1^{(k)} \gamma_k \mathbf{1}_n = \mathbf{h}
+
+ where :
+
+ - :math:`\lambda_k` is the weight of k-th source domain
+ - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see ot.bregman.sinkhorn)
+ - :math:`\mathbf{D}_2^{(k)}` is a matrix of weights related to k-th source domain defined as in [p. 5, 27], its expected shape is `(n_k, C)` where `n_k` is the number of elements in the k-th source domain and `C` is the number of classes
+ - :math:`\mathbf{h}` is a vector of estimated proportions in the target domain of size `C`
+ - :math:`\mathbf{a}` is a uniform vector of weights in the target domain of size `n`
+ - :math:`\mathbf{D}_1^{(k)}` is a matrix of class assignments defined as in [p. 5, 27], its expected shape is `(n_k, C)`
+
+ The problem consists in solving a Wasserstein barycenter problem to estimate the proportions :math:`\mathbf{h}` in the target domain.
+
+ The algorithm used for solving the problem is the Iterative Bregman projections algorithm
+ with two sets of marginal constraints related to the unknown vector :math:`\mathbf{h}` and uniform target distribution.
+
+ Parameters
+ ----------
+ Xs : list of K np.ndarray(nsk,d)
+ features of all source domains' samples
+ Ys : list of K np.ndarray(nsk,)
+ labels of all source domains' samples
+ Xt : np.ndarray (nt,d)
+ samples in the target domain
+ reg : float
+ Regularization term > 0
+ metric : string, optional (default="sqeuclidean")
+ The ground metric for the Wasserstein problem
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshold on relative change in the barycenter (>0)
+ log : bool, optional
+ record log if True
+ verbose : bool, optional (default=False)
+ Controls the verbosity of the optimization algorithm
+
+ Returns
+ -------
+ h : (C,) ndarray
+ proportion estimation in the target domain
+ log : dict
+ log dictionary returned only if log==True in parameters
+
+
+ References
+ ----------
+
+ .. [27] Ievgen Redko, Nicolas Courty, Rémi Flamary, Devis Tuia
+ "Optimal transport for multi-source domain adaptation under target shift",
+ International Conference on Artificial Intelligence and Statistics (AISTATS), 2019.
+
+ '''
+ nbclasses = len(np.unique(Ys[0]))
+ nbdomains = len(Xs)
+
+ # log dictionary
+ if log:
+ log = {'niter': 0, 'err': [], 'M': [], 'D1': [], 'D2': [], 'gamma': []}
+
+ K = []
+ M = []
+ D1 = []
+ D2 = []
+
+ # For each source domain, build cost matrices M, Gibbs kernels K and corresponding matrices D_1 and D_2
+ for d in range(nbdomains):
+ dom = {}
+ nsk = Xs[d].shape[0] # get number of elements for this domain
+ dom['nbelem'] = nsk
+ classes = np.unique(Ys[d]) # get number of classes for this domain
+
+ # format classes to start from 0 for convenience
+ if np.min(classes) != 0:
+ Ys[d] = Ys[d] - np.min(classes)
+ classes = np.unique(Ys[d])
+
+ # build the corresponding D_1 and D_2 matrices
+ Dtmp1 = np.zeros((nbclasses, nsk))
+ Dtmp2 = np.zeros((nbclasses, nsk))
+
+ for c in classes:
+ nbelemperclass = np.sum(Ys[d] == c)
+ if nbelemperclass != 0:
+ Dtmp1[int(c), Ys[d] == c] = 1.
+ Dtmp2[int(c), Ys[d] == c] = 1. / (nbelemperclass)
+ D1.append(Dtmp1)
+ D2.append(Dtmp2)
+
+ # build the cost matrix and the Gibbs kernel
+ Mtmp = dist(Xs[d], Xt, metric=metric)
+ M.append(Mtmp)
+
+ Ktmp = np.empty(Mtmp.shape, dtype=Mtmp.dtype)
+ np.divide(Mtmp, -reg, out=Ktmp)
+ np.exp(Ktmp, out=Ktmp)
+ K.append(Ktmp)
+
+ # uniform target distribution
+ a = unif(np.shape(Xt)[0])
+
+ cpt = 0 # iterations count
+ err = 1
+ old_bary = np.ones((nbclasses))
+
+ while (err > stopThr and cpt < numItermax):
+
+ bary = np.zeros((nbclasses))
+
+ # update coupling matrices for marginal constraints w.r.t. uniform target distribution
+ for d in range(nbdomains):
+ K[d] = projC(K[d], a)
+ other = np.sum(K[d], axis=1)
+ bary = bary + np.log(np.dot(D1[d], other)) / nbdomains
+
+ bary = np.exp(bary)
+
+ # update coupling matrices for marginal constraints w.r.t. unknown proportions based on [Prop 4., 27]
+ for d in range(nbdomains):
+ new = np.dot(D2[d].T, bary)
+ K[d] = projR(K[d], new)
+
+ err = np.linalg.norm(bary - old_bary)
+ cpt = cpt + 1
+ old_bary = bary
+
+ if log:
+ log['err'].append(err)
+
+ if verbose:
+ if cpt % 200 == 0:
+ print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, err))
+
+ bary = bary / np.sum(bary)
+
+ if log:
+ log['niter'] = cpt
+ log['M'] = M
+ log['D1'] = D1
+ log['D2'] = D2
+ log['gamma'] = K
+ return bary, log
+ else:
+ return bary
+
+
def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
numIterMax=10000, stopThr=1e-9, verbose=False,
log=False, **kwargs):
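
The main loop of `jcpot_barycenter` above alternates two Bregman projections
on each Gibbs kernel: `projC(K[d], a)` rescales columns to match the fixed
uniform target marginal, and `projR(K[d], new)` with `new = D2[d].T @ bary`
rescales rows to match the current proportion estimate. A minimal numpy sketch
of these projections, written out standalone (they mirror ot.bregman's own
`projR`/`projC` helpers):

    import numpy as np

    def projR(gamma, p):
        # rescale rows so that gamma.sum(axis=1) == p
        return (gamma.T * (p / np.maximum(gamma.sum(axis=1), 1e-10))).T

    def projC(gamma, q):
        # rescale columns so that gamma.sum(axis=0) == q
        return gamma * (q / np.maximum(gamma.sum(axis=0), 1e-10))

    rng = np.random.RandomState(0)
    gamma = rng.rand(4, 6)                  # source samples x target samples
    a = np.full(6, 1. / 6)                  # uniform target marginal
    h = np.array([0.1, 0.2, 0.3, 0.4])      # row marginal, e.g. D2.T @ bary

    gamma = projC(gamma, a)
    assert np.allclose(gamma.sum(axis=0), a)
    gamma = projR(gamma, h)
    assert np.allclose(gamma.sum(axis=1), h)
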
@@ -1593,7 +1752,8 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
return pi
-def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, verbose=False, log=False, **kwargs):
+def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9,
+ verbose=False, log=False, **kwargs):
r'''
Solve the entropic regularization optimal transport problem from empirical
data and return the OT loss
@@ -1675,14 +1835,17 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num
M = dist(X_s, X_t, metric=metric)
if log:
- sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+ sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log,
+ **kwargs)
return sinkhorn_loss, log
else:
- sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+ sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log,
+ **kwargs)
return sinkhorn_loss
-def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, verbose=False, log=False, **kwargs):
+def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9,
+ verbose=False, log=False, **kwargs):
r'''
Compute the sinkhorn divergence loss from empirical data
@@ -1768,11 +1931,14 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
.. [23] Aude Genevay, Gabriel Peyré, Marco Cuturi, Learning Generative Models with Sinkhorn Divergences, Proceedings of the Twenty-First International Conference on Artficial Intelligence and Statistics, (AISTATS) 21, 2018
'''
if log:
- sinkhorn_loss_ab, log_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs)
+ sinkhorn_loss_ab, log_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax,
+ stopThr=1e-9, verbose=verbose, log=log, **kwargs)
- sinkhorn_loss_a, log_a = empirical_sinkhorn2(X_s, X_s, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs)
+ sinkhorn_loss_a, log_a = empirical_sinkhorn2(X_s, X_s, reg, a, b, metric=metric, numIterMax=numIterMax,
+ stopThr=1e-9, verbose=verbose, log=log, **kwargs)
- sinkhorn_loss_b, log_b = empirical_sinkhorn2(X_t, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs)
+ sinkhorn_loss_b, log_b = empirical_sinkhorn2(X_t, X_t, reg, a, b, metric=metric, numIterMax=numIterMax,
+ stopThr=1e-9, verbose=verbose, log=log, **kwargs)
sinkhorn_div = sinkhorn_loss_ab - 1 / 2 * (sinkhorn_loss_a + sinkhorn_loss_b)
@@ -1787,11 +1953,14 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
return max(0, sinkhorn_div), log
else:
- sinkhorn_loss_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs)
+ sinkhorn_loss_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9,
+ verbose=verbose, log=log, **kwargs)
- sinkhorn_loss_a = empirical_sinkhorn2(X_s, X_s, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs)
+ sinkhorn_loss_a = empirical_sinkhorn2(X_s, X_s, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9,
+ verbose=verbose, log=log, **kwargs)
- sinkhorn_loss_b = empirical_sinkhorn2(X_t, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs)
+ sinkhorn_loss_b = empirical_sinkhorn2(X_t, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9,
+ verbose=verbose, log=log, **kwargs)
sinkhorn_div = sinkhorn_loss_ab - 1 / 2 * (sinkhorn_loss_a + sinkhorn_loss_b)
return max(0, sinkhorn_div)
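
Both branches above compute the debiased divergence
S(X_s, X_t) = W(X_s, X_t) - (W(X_s, X_s) + W(X_t, X_t)) / 2, clipped at zero.
A standalone sketch using the public functions this hunk touches (the toy
data is illustrative):

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    X_s = rng.randn(30, 2)
    X_t = rng.randn(40, 2) + 1.0            # shifted target cloud

    w_ab = ot.bregman.empirical_sinkhorn2(X_s, X_t, reg=1.0)
    w_aa = ot.bregman.empirical_sinkhorn2(X_s, X_s, reg=1.0)
    w_bb = ot.bregman.empirical_sinkhorn2(X_t, X_t, reg=1.0)
    div = np.maximum(0, w_ab - 0.5 * (w_aa + w_bb))  # debiased, clipped at zero

    # equivalently, in one call:
    div2 = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, reg=1.0)
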
@@ -1883,7 +2052,8 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res
try:
import bottleneck
except ImportError:
- warnings.warn("Bottleneck module is not installed. Install it from https://pypi.org/project/Bottleneck/ for better performance.")
+ warnings.warn(
+ "Bottleneck module is not installed. Install it from https://pypi.org/project/Bottleneck/ for better performance.")
bottleneck = np
a = np.asarray(a, dtype=np.float64)
@@ -2019,8 +2189,9 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res
bounds_u = [(max(a_I_min / ((nt - nt_budget) * epsilon + nt_budget * (b_J_max / (
ns * epsilon * kappa * K_min))), epsilon / kappa), a_I_max / (nt * epsilon * K_min))] * ns_budget
- bounds_v = [(max(b_J_min / ((ns - ns_budget) * epsilon + ns_budget * (kappa * a_I_max / (nt * epsilon * K_min))),
- epsilon * kappa), b_J_max / (ns * epsilon * K_min))] * nt_budget
+ bounds_v = [(
+ max(b_J_min / ((ns - ns_budget) * epsilon + ns_budget * (kappa * a_I_max / (nt * epsilon * K_min))),
+ epsilon * kappa), b_J_max / (ns * epsilon * K_min))] * nt_budget
# pre-calculated constants for the objective
vec_eps_IJc = epsilon * kappa * (K_IJc * np.ones(nt - nt_budget).reshape((1, -1))).sum(axis=1)
@@ -2069,7 +2240,8 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res
return usc, vsc
def screened_obj(usc, vsc):
- part_IJ = np.dot(np.dot(usc, K_IJ), vsc) - kappa * np.dot(a_I, np.log(usc)) - (1. / kappa) * np.dot(b_J, np.log(vsc))
+ part_IJ = np.dot(np.dot(usc, K_IJ), vsc) - kappa * np.dot(a_I, np.log(usc)) - (1. / kappa) * np.dot(b_J,
+ np.log(vsc))
part_IJc = np.dot(usc, vec_eps_IJc)
part_IcJ = np.dot(vec_eps_IcJ, vsc)
psi_epsilon = part_IJ + part_IJc + part_IcJ
@@ -2091,9 +2263,9 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res
g = np.hstack([g_u, g_v])
return f, g
- #----------------------------------------------------------------------------------------------------------------#
+ # ----------------------------------------------------------------------------------------------------------------#
# Step 2: L-BFGS-B solver #
- #----------------------------------------------------------------------------------------------------------------#
+ # ----------------------------------------------------------------------------------------------------------------#
u0, v0 = restricted_sinkhorn(u0, v0)
theta0 = np.hstack([u0, v0])