 ot/bregman.py     | 72
 ot/datasets.py    |  4
 ot/lp/__init__.py |  4
 ot/plot.py        |  3
 4 files changed, 49 insertions(+), 34 deletions(-)
diff --git a/ot/bregman.py b/ot/bregman.py
index fb959e9..951d3ce 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -19,6 +19,7 @@ import warnings
 from .utils import unif, dist
 from scipy.optimize import fmin_l_bfgs_b
 
+
 def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
              stopThr=1e-9, verbose=False, log=False, **kwargs):
     r"""
@@ -539,12 +540,12 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False,
             old_v = v[i_2]
             v[i_2] = b[i_2] / (K[:, i_2].T.dot(u))
             G[:, i_2] = u * K[:, i_2] * v[i_2]
-            #aviol = (G@one_m - a)
-            #aviol_2 = (G.T@one_n - b)
+            # aviol = (G@one_m - a)
+            # aviol_2 = (G.T@one_n - b)
             viol += (-old_v + v[i_2]) * K[:, i_2] * u
             viol_2[i_2] = v[i_2] * K[:, i_2].dot(u) - b[i_2]
-            #print('b',np.max(abs(aviol -viol)),np.max(abs(aviol_2 - viol_2)))
+            # print('b',np.max(abs(aviol -viol)),np.max(abs(aviol_2 - viol_2)))
 
         if stopThr_val <= stopThr:
             break
@@ -715,7 +716,7 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
         if np.abs(u).max() > tau or np.abs(v).max() > tau:
             if n_hists:
                 alpha, beta = alpha + reg * \
-                    np.max(np.log(u), 1), beta + reg * np.max(np.log(v))
+                              np.max(np.log(u), 1), beta + reg * np.max(np.log(v))
             else:
                 alpha, beta = alpha + reg * np.log(u), beta + reg * np.log(v)
             if n_hists:
@@ -940,7 +941,7 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4,
                 # the 10th iterations
                 transp = G
                 err = np.linalg.norm(
-                    (np.sum(transp, axis=0) - b))**2 + np.linalg.norm((np.sum(transp, axis=1) - a))**2
+                    (np.sum(transp, axis=0) - b)) ** 2 + np.linalg.norm((np.sum(transp, axis=1) - a)) ** 2
                 if log:
                     log['err'].append(err)
@@ -966,7 +967,7 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4,
 
 def geometricBar(weights, alldistribT):
     """return the weighted geometric mean of distributions"""
-    assert(len(weights) == alldistribT.shape[1])
+    assert (len(weights) == alldistribT.shape[1])
     return np.exp(np.dot(np.log(alldistribT), weights.T))
 
@@ -1108,7 +1109,7 @@ def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000,
     if weights is None:
         weights = np.ones(A.shape[1]) / A.shape[1]
     else:
-        assert(len(weights) == A.shape[1])
+        assert (len(weights) == A.shape[1])
 
     if log:
         log = {'err': []}
@@ -1206,7 +1207,7 @@ def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000,
     if weights is None:
         weights = np.ones(n_hists) / n_hists
     else:
-        assert(len(weights) == A.shape[1])
+        assert (len(weights) == A.shape[1])
 
     if log:
         log = {'err': []}
@@ -1334,7 +1335,7 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000,
     if weights is None:
         weights = np.ones(A.shape[0]) / A.shape[0]
     else:
-        assert(len(weights) == A.shape[0])
+        assert (len(weights) == A.shape[0])
 
     if log:
         log = {'err': []}
@@ -1350,11 +1351,11 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000,
     # this is equivalent to blurring on horizontal then vertical directions
     t = np.linspace(0, 1, A.shape[1])
     [Y, X] = np.meshgrid(t, t)
-    xi1 = np.exp(-(X - Y)**2 / reg)
+    xi1 = np.exp(-(X - Y) ** 2 / reg)
 
     t = np.linspace(0, 1, A.shape[2])
     [Y, X] = np.meshgrid(t, t)
-    xi2 = np.exp(-(X - Y)**2 / reg)
+    xi2 = np.exp(-(X - Y) ** 2 / reg)
 
     def K(x):
         return np.dot(np.dot(xi1, x), xi2)
@@ -1501,6 +1502,7 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
     else:
         return np.sum(K0, axis=1)
 
+
 def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
                      stopThr=1e-6, verbose=False, log=False, **kwargs):
     r'''Joint OT and proportion estimation for multi-source target shift as proposed in [27]
@@ -1658,6 +1660,7 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
     else:
         return couplings, bary
 
+
 def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
                        numIterMax=10000, stopThr=1e-9, verbose=False,
                        log=False, **kwargs):
@@ -1749,7 +1752,8 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
         return pi
 
 
-def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, verbose=False, log=False, **kwargs):
+def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9,
+                        verbose=False, log=False, **kwargs):
     r'''
     Solve the entropic regularization optimal transport problem from empirical
     data and return the OT loss
@@ -1831,14 +1835,17 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num
     M = dist(X_s, X_t, metric=metric)
 
     if log:
-        sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+        sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log,
+                                       **kwargs)
         return sinkhorn_loss, log
     else:
-        sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+        sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log,
+                                  **kwargs)
         return sinkhorn_loss
 
 
-def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, verbose=False, log=False, **kwargs):
+def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9,
+                                  verbose=False, log=False, **kwargs):
     r'''
     Compute the sinkhorn divergence loss from empirical data
@@ -1924,11 +1931,14 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
     .. [23] Aude Genevay, Gabriel Peyré, Marco Cuturi, Learning Generative Models with Sinkhorn Divergences, Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics, (AISTATS) 21, 2018
     '''
     if log:
-        sinkhorn_loss_ab, log_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs)
+        sinkhorn_loss_ab, log_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax,
+                                                       stopThr=1e-9, verbose=verbose, log=log, **kwargs)
 
-        sinkhorn_loss_a, log_a = empirical_sinkhorn2(X_s, X_s, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs)
+        sinkhorn_loss_a, log_a = empirical_sinkhorn2(X_s, X_s, reg, a, b, metric=metric, numIterMax=numIterMax,
+                                                     stopThr=1e-9, verbose=verbose, log=log, **kwargs)
 
-        sinkhorn_loss_b, log_b = empirical_sinkhorn2(X_t, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs)
+        sinkhorn_loss_b, log_b = empirical_sinkhorn2(X_t, X_t, reg, a, b, metric=metric, numIterMax=numIterMax,
+                                                     stopThr=1e-9, verbose=verbose, log=log, **kwargs)
 
         sinkhorn_div = sinkhorn_loss_ab - 1 / 2 * (sinkhorn_loss_a + sinkhorn_loss_b)
@@ -1943,11 +1953,14 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
         return max(0, sinkhorn_div), log
 
     else:
-        sinkhorn_loss_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs)
+        sinkhorn_loss_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9,
+                                               verbose=verbose, log=log, **kwargs)
 
-        sinkhorn_loss_a = empirical_sinkhorn2(X_s, X_s, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs)
+        sinkhorn_loss_a = empirical_sinkhorn2(X_s, X_s, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9,
+                                              verbose=verbose, log=log, **kwargs)
 
-        sinkhorn_loss_b = empirical_sinkhorn2(X_t, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs)
+        sinkhorn_loss_b = empirical_sinkhorn2(X_t, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9,
+                                              verbose=verbose, log=log, **kwargs)
 
         sinkhorn_div = sinkhorn_loss_ab - 1 / 2 * (sinkhorn_loss_a + sinkhorn_loss_b)
         return max(0, sinkhorn_div)
@@ -2039,7 +2052,8 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res
     try:
         import bottleneck
     except ImportError:
-        warnings.warn("Bottleneck module is not installed. Install it from https://pypi.org/project/Bottleneck/ for better performance.")
+        warnings.warn(
+            "Bottleneck module is not installed. Install it from https://pypi.org/project/Bottleneck/ for better performance.")
         bottleneck = np
 
     a = np.asarray(a, dtype=np.float64)
@@ -2173,10 +2187,11 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res
 
     # box constraints in L-BFGS-B (see Proposition 1 in [26])
     bounds_u = [(max(a_I_min / ((nt - nt_budget) * epsilon + nt_budget * (b_J_max / (
-        ns * epsilon * kappa * K_min))), epsilon / kappa), a_I_max / (nt * epsilon * K_min))] * ns_budget
+            ns * epsilon * kappa * K_min))), epsilon / kappa), a_I_max / (nt * epsilon * K_min))] * ns_budget
 
-    bounds_v = [(max(b_J_min / ((ns - ns_budget) * epsilon + ns_budget * (kappa * a_I_max / (nt * epsilon * K_min))),
-                     epsilon * kappa), b_J_max / (ns * epsilon * K_min))] * nt_budget
+    bounds_v = [(
+        max(b_J_min / ((ns - ns_budget) * epsilon + ns_budget * (kappa * a_I_max / (nt * epsilon * K_min))),
+            epsilon * kappa), b_J_max / (ns * epsilon * K_min))] * nt_budget
 
     # pre-calculated constants for the objective
     vec_eps_IJc = epsilon * kappa * (K_IJc * np.ones(nt - nt_budget).reshape((1, -1))).sum(axis=1)
@@ -2225,7 +2240,8 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res
         return usc, vsc
 
     def screened_obj(usc, vsc):
-        part_IJ = np.dot(np.dot(usc, K_IJ), vsc) - kappa * np.dot(a_I, np.log(usc)) - (1. / kappa) * np.dot(b_J, np.log(vsc))
+        part_IJ = np.dot(np.dot(usc, K_IJ), vsc) - kappa * np.dot(a_I, np.log(usc)) - (1. / kappa) * np.dot(b_J,
+                                                                                                            np.log(vsc))
         part_IJc = np.dot(usc, vec_eps_IJc)
         part_IcJ = np.dot(vec_eps_IcJ, vsc)
         psi_epsilon = part_IJ + part_IJc + part_IcJ
@@ -2247,9 +2263,9 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res
         g = np.hstack([g_u, g_v])
         return f, g
 
-    #----------------------------------------------------------------------------------------------------------------#
+    # ----------------------------------------------------------------------------------------------------------------#
     #                                           Step 2: L-BFGS-B solver                                               #
-    #----------------------------------------------------------------------------------------------------------------#
+    # ----------------------------------------------------------------------------------------------------------------#
 
     u0, v0 = restricted_sinkhorn(u0, v0)
     theta0 = np.hstack([u0, v0])
diff --git a/ot/datasets.py b/ot/datasets.py
index eea9f37..a1ca7b6 100644
--- a/ot/datasets.py
+++ b/ot/datasets.py
@@ -30,7 +30,7 @@ def make_1D_gauss(n, m, s):
         1D histogram for a gaussian distribution
     """
     x = np.arange(n, dtype=np.float64)
-    h = np.exp(-(x - m)**2 / (2 * s**2))
+    h = np.exp(-(x - m) ** 2 / (2 * s ** 2))
     return h / h.sum()
 
 
@@ -80,7 +80,7 @@ def get_2D_samples_gauss(n, m, sigma, random_state=None):
     return make_2D_samples_gauss(n, m, sigma, random_state=None)
 
 
-def make_data_classif(dataset, n, nz=.5, theta=0, p = .5, random_state=None, **kwargs):
+def make_data_classif(dataset, n, nz=.5, theta=0, p=.5, random_state=None, **kwargs):
     """Dataset generation for classification problems
 
     Parameters
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index cdd505d..7eaa44a 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -2,8 +2,6 @@
 
 """
 Solvers for the original linear program OT problem
-
-
 """
 
 # Author: Remi Flamary <remi.flamary@unice.fr>
@@ -18,7 +16,7 @@ from scipy.sparse import coo_matrix
 from .import cvx
 
 # import compiled emd
-from .emd_wrap import emd_c, check_result, emd_1d_sorted
+#from .emd_wrap import emd_c, check_result, emd_1d_sorted
 from ..utils import parmap
 from .cvx import barycenter
 from ..utils import dist
diff --git a/ot/plot.py b/ot/plot.py
--- a/ot/plot.py
+++ b/ot/plot.py
@@ -78,9 +78,10 @@ def plot2D_samples_mat(xs, xt, G, thr=1e-8, **kwargs):
     thr : float, optional
         threshold above which the line is drawn
     **kwargs : dict
-        paameters given to the plot functions (default color is black if
+        parameters given to the plot functions (default color is black if
         nothing given)
     """
+
     if ('color' not in kwargs) and ('c' not in kwargs):
         kwargs['color'] = 'k'
     mx = G.max()
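
Note: the empirical_sinkhorn_divergence code reflowed above computes the Sinkhorn divergence S(X_s, X_t) = W(X_s, X_t) - (W(X_s, X_s) + W(X_t, X_t)) / 2, clamped at zero, where W is the entropic OT loss returned by empirical_sinkhorn2. The following Python snippet is a minimal sketch of that identity, not part of the commit; the sample sizes, the shift of the target cloud, and the regularization strength are arbitrary values chosen for illustration.

    import numpy as np
    import ot.bregman

    rng = np.random.RandomState(0)
    X_s = rng.randn(50, 2)        # source samples (size chosen arbitrarily)
    X_t = rng.randn(60, 2) + 1.0  # target samples, shifted by one unit
    reg = 0.5                     # entropic regularization strength (arbitrary)

    # Library call: returns max(0, W_ab - (W_aa + W_bb) / 2) with uniform weights.
    div = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, reg)

    # The same divergence assembled by hand from three entropic OT losses,
    # mirroring the sinkhorn_loss_ab - 1 / 2 * (...) expression in the diff.
    w_ab = float(ot.bregman.empirical_sinkhorn2(X_s, X_t, reg))
    w_aa = float(ot.bregman.empirical_sinkhorn2(X_s, X_s, reg))
    w_bb = float(ot.bregman.empirical_sinkhorn2(X_t, X_t, reg))
    manual = max(0.0, w_ab - 0.5 * (w_aa + w_bb))

    print(div, manual)  # the two values should agree up to floating-point error

Unlike the raw entropic loss w_ab, the debiased divergence vanishes when the two samples coincide, which is why the self-comparison terms w_aa and w_bb are subtracted.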