From 6fdf5de8fa27fa16d6b8910fe96eb67b7761aa0e Mon Sep 17 00:00:00 2001 From: RĂ©mi Flamary Date: Wed, 21 Mar 2018 08:29:50 +0100 Subject: add linear mapping test + autopep8 --- examples/plot_otda_linear_mapping.py | 5 ++--- ot/bregman.py | 30 ++++++++++++++++++++---------- ot/lp/__init__.py | 3 ++- ot/optim.py | 9 ++++++--- ot/utils.py | 2 +- test/test_da.py | 18 ++++++++++++++++++ 6 files changed, 49 insertions(+), 18 deletions(-) diff --git a/examples/plot_otda_linear_mapping.py b/examples/plot_otda_linear_mapping.py index 165fe72..7a3b761 100644 --- a/examples/plot_otda_linear_mapping.py +++ b/examples/plot_otda_linear_mapping.py @@ -9,7 +9,6 @@ Created on Tue Mar 20 14:31:15 2018 import numpy as np import pylab as pl import ot -from scipy import ndimage ############################################################################## # Generate data @@ -87,8 +86,8 @@ def minmax(I): # Loading images -I1 = ndimage.imread('../data/ocean_day.jpg').astype(np.float64) / 256 -I2 = ndimage.imread('../data/ocean_sunset.jpg').astype(np.float64) / 256 +I1 = pl.imread('../data/ocean_day.jpg').astype(np.float64) / 256 +I2 = pl.imread('../data/ocean_sunset.jpg').astype(np.float64) / 256 X1 = im2mat(I1) diff --git a/ot/bregman.py b/ot/bregman.py index d63c51d..07b8660 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -11,7 +11,8 @@ Bregman projections for regularized OT import numpy as np -def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, stopThr=1e-9, verbose=False, log=False, **kwargs): +def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, + stopThr=1e-9, verbose=False, log=False, **kwargs): u""" Solve the entropic regularization optimal transport problem and return the OT matrix @@ -120,7 +121,8 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, stopThr=1e-9, ver return sink() -def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, stopThr=1e-9, verbose=False, log=False, **kwargs): +def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, + stopThr=1e-9, verbose=False, log=False, **kwargs): u""" Solve the entropic regularization optimal transport problem and return the loss @@ -233,7 +235,8 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, stopThr=1e-9, ve return sink() -def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False, log=False, **kwargs): +def sinkhorn_knopp(a, b, M, reg, numItermax=1000, + stopThr=1e-9, verbose=False, log=False, **kwargs): """ Solve the entropic regularization optimal transport problem and return the OT matrix @@ -403,7 +406,8 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False, l return u.reshape((-1, 1)) * K * v.reshape((1, -1)) -def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, warmstart=None, verbose=False, print_period=20, log=False, **kwargs): +def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, + warmstart=None, verbose=False, print_period=20, log=False, **kwargs): """ Solve the entropic regularization OT problem with log stabilization @@ -526,11 +530,13 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, wa def get_K(alpha, beta): """log space computation""" - return np.exp(-(M - alpha.reshape((na, 1)) - beta.reshape((1, nb))) / reg) + return np.exp(-(M - alpha.reshape((na, 1)) - + beta.reshape((1, nb))) / reg) def get_Gamma(alpha, beta, u, v): """log space gamma computation""" - return np.exp(-(M - alpha.reshape((na, 1)) - beta.reshape((1, nb))) / reg + np.log(u.reshape((na, 1))) + np.log(v.reshape((1, nb)))) + return np.exp(-(M - alpha.reshape((na, 1)) - beta.reshape((1, nb))) / + reg + np.log(u.reshape((na, 1))) + np.log(v.reshape((1, nb)))) # print(np.min(K)) @@ -620,7 +626,8 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, wa return get_Gamma(alpha, beta, u, v) -def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInnerItermax=100, tau=1e3, stopThr=1e-9, warmstart=None, verbose=False, print_period=10, log=False, **kwargs): +def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInnerItermax=100, + tau=1e3, stopThr=1e-9, warmstart=None, verbose=False, print_period=10, log=False, **kwargs): """ Solve the entropic regularization optimal transport problem with log stabilization and epsilon scaling. @@ -739,7 +746,8 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInne def get_K(alpha, beta): """log space computation""" - return np.exp(-(M - alpha.reshape((na, 1)) - beta.reshape((1, nb))) / reg) + return np.exp(-(M - alpha.reshape((na, 1)) - + beta.reshape((1, nb))) / reg) # print(np.min(K)) def get_reg(n): # exponential decreasing @@ -811,7 +819,8 @@ def projC(gamma, q): return np.multiply(gamma, q / np.maximum(np.sum(gamma, axis=0), 1e-10)) -def barycenter(A, M, reg, weights=None, numItermax=1000, stopThr=1e-4, verbose=False, log=False): +def barycenter(A, M, reg, weights=None, numItermax=1000, + stopThr=1e-4, verbose=False, log=False): """Compute the entropic regularized wasserstein barycenter of distributions A The function solves the following optimization problem: @@ -904,7 +913,8 @@ def barycenter(A, M, reg, weights=None, numItermax=1000, stopThr=1e-4, verbose=F return geometricBar(weights, UKv) -def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, stopThr=1e-3, verbose=False, log=False): +def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, + stopThr=1e-3, verbose=False, log=False): """ Compute the unmixing of an observation with a given dictionary using Wasserstein distance diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 5c09da2..6371feb 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -107,7 +107,8 @@ def emd(a, b, M, numItermax=100000, log=False): return G -def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=100000, log=False, return_matrix=False): +def emd2(a, b, M, processes=multiprocessing.cpu_count(), + numItermax=100000, log=False, return_matrix=False): """Solves the Earth Movers distance problem and returns the loss .. math:: diff --git a/ot/optim.py b/ot/optim.py index 1d09adc..f31fae2 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -15,7 +15,8 @@ from .bregman import sinkhorn # The corresponding scipy function does not work for matrices -def line_search_armijo(f, xk, pk, gfk, old_fval, args=(), c1=1e-4, alpha0=0.99): +def line_search_armijo(f, xk, pk, gfk, old_fval, + args=(), c1=1e-4, alpha0=0.99): """ Armijo linesearch function that works with matrices @@ -71,7 +72,8 @@ def line_search_armijo(f, xk, pk, gfk, old_fval, args=(), c1=1e-4, alpha0=0.99): return alpha, fc[0], phi1 -def cg(a, b, M, reg, f, df, G0=None, numItermax=200, stopThr=1e-9, verbose=False, log=False): +def cg(a, b, M, reg, f, df, G0=None, numItermax=200, + stopThr=1e-9, verbose=False, log=False): """ Solve the general regularized OT problem with conditional gradient @@ -202,7 +204,8 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, stopThr=1e-9, verbose=False return G -def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10, numInnerItermax=200, stopThr=1e-9, verbose=False, log=False): +def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10, + numInnerItermax=200, stopThr=1e-9, verbose=False, log=False): """ Solve the general regularized OT problem with the generalized conditional gradient diff --git a/ot/utils.py b/ot/utils.py index 9eab3fc..16862ea 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -316,7 +316,7 @@ def _is_deprecated(func): closures = [] is_deprecated = ('deprecated' in ''.join([c.cell_contents for c in closures - if isinstance(c.cell_contents, str)])) + if isinstance(c.cell_contents, str)])) return is_deprecated diff --git a/test/test_da.py b/test/test_da.py index 593dc53..7b63daf 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -444,6 +444,24 @@ def test_mapping_transport_class(): assert len(otda.log_.keys()) != 0 +def test_linear_mapping(): + + ns = 150 + nt = 200 + + Xs, ys = get_data_classif('3gauss', ns) + Xt, yt = get_data_classif('3gauss2', nt) + + A, b = ot.da.OT_mapping_linear(Xs, Xt) + + Xst = Xs.dot(A) + b + + Ct = np.cov(Xt.T) + Cst = np.cov(Xst.T) + + np.testing.assert_allclose(Ct, Cst, rtol=1e-2, atol=1e-2) + + def test_otda(): n_samples = 150 # nb samples -- cgit v1.2.3