summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2018-03-21 08:29:50 +0100
committerRémi Flamary <remi.flamary@gmail.com>2018-03-21 08:29:50 +0100
commit6fdf5de8fa27fa16d6b8910fe96eb67b7761aa0e (patch)
tree55b8d6e463b131134b4aa9d4e3013bbb77811da6
parent287c659ad35f5036ba2687caf73009ef455c7239 (diff)
add linear mapping test + autopep8
-rw-r--r--examples/plot_otda_linear_mapping.py5
-rw-r--r--ot/bregman.py30
-rw-r--r--ot/lp/__init__.py3
-rw-r--r--ot/optim.py9
-rw-r--r--ot/utils.py2
-rw-r--r--test/test_da.py18
6 files changed, 49 insertions, 18 deletions
diff --git a/examples/plot_otda_linear_mapping.py b/examples/plot_otda_linear_mapping.py
index 165fe72..7a3b761 100644
--- a/examples/plot_otda_linear_mapping.py
+++ b/examples/plot_otda_linear_mapping.py
@@ -9,7 +9,6 @@ Created on Tue Mar 20 14:31:15 2018
import numpy as np
import pylab as pl
import ot
-from scipy import ndimage
##############################################################################
# Generate data
@@ -87,8 +86,8 @@ def minmax(I):
# Loading images
-I1 = ndimage.imread('../data/ocean_day.jpg').astype(np.float64) / 256
-I2 = ndimage.imread('../data/ocean_sunset.jpg').astype(np.float64) / 256
+I1 = pl.imread('../data/ocean_day.jpg').astype(np.float64) / 256
+I2 = pl.imread('../data/ocean_sunset.jpg').astype(np.float64) / 256
X1 = im2mat(I1)
diff --git a/ot/bregman.py b/ot/bregman.py
index d63c51d..07b8660 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -11,7 +11,8 @@ Bregman projections for regularized OT
import numpy as np
-def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, stopThr=1e-9, verbose=False, log=False, **kwargs):
+def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
+ stopThr=1e-9, verbose=False, log=False, **kwargs):
u"""
Solve the entropic regularization optimal transport problem and return the OT matrix
@@ -120,7 +121,8 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, stopThr=1e-9, ver
return sink()
-def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, stopThr=1e-9, verbose=False, log=False, **kwargs):
+def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
+ stopThr=1e-9, verbose=False, log=False, **kwargs):
u"""
Solve the entropic regularization optimal transport problem and return the loss
@@ -233,7 +235,8 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, stopThr=1e-9, ve
return sink()
-def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False, log=False, **kwargs):
+def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
+ stopThr=1e-9, verbose=False, log=False, **kwargs):
"""
Solve the entropic regularization optimal transport problem and return the OT matrix
@@ -403,7 +406,8 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False, l
return u.reshape((-1, 1)) * K * v.reshape((1, -1))
-def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, warmstart=None, verbose=False, print_period=20, log=False, **kwargs):
+def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
+ warmstart=None, verbose=False, print_period=20, log=False, **kwargs):
"""
Solve the entropic regularization OT problem with log stabilization
@@ -526,11 +530,13 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, wa
def get_K(alpha, beta):
"""log space computation"""
- return np.exp(-(M - alpha.reshape((na, 1)) - beta.reshape((1, nb))) / reg)
+ return np.exp(-(M - alpha.reshape((na, 1)) -
+ beta.reshape((1, nb))) / reg)
def get_Gamma(alpha, beta, u, v):
"""log space gamma computation"""
- return np.exp(-(M - alpha.reshape((na, 1)) - beta.reshape((1, nb))) / reg + np.log(u.reshape((na, 1))) + np.log(v.reshape((1, nb))))
+ return np.exp(-(M - alpha.reshape((na, 1)) - beta.reshape((1, nb))) /
+ reg + np.log(u.reshape((na, 1))) + np.log(v.reshape((1, nb))))
# print(np.min(K))
@@ -620,7 +626,8 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, wa
return get_Gamma(alpha, beta, u, v)
-def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInnerItermax=100, tau=1e3, stopThr=1e-9, warmstart=None, verbose=False, print_period=10, log=False, **kwargs):
+def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInnerItermax=100,
+ tau=1e3, stopThr=1e-9, warmstart=None, verbose=False, print_period=10, log=False, **kwargs):
"""
Solve the entropic regularization optimal transport problem with log
stabilization and epsilon scaling.
@@ -739,7 +746,8 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInne
def get_K(alpha, beta):
"""log space computation"""
- return np.exp(-(M - alpha.reshape((na, 1)) - beta.reshape((1, nb))) / reg)
+ return np.exp(-(M - alpha.reshape((na, 1)) -
+ beta.reshape((1, nb))) / reg)
# print(np.min(K))
def get_reg(n): # exponential decreasing
@@ -811,7 +819,8 @@ def projC(gamma, q):
return np.multiply(gamma, q / np.maximum(np.sum(gamma, axis=0), 1e-10))
-def barycenter(A, M, reg, weights=None, numItermax=1000, stopThr=1e-4, verbose=False, log=False):
+def barycenter(A, M, reg, weights=None, numItermax=1000,
+ stopThr=1e-4, verbose=False, log=False):
"""Compute the entropic regularized wasserstein barycenter of distributions A
The function solves the following optimization problem:
@@ -904,7 +913,8 @@ def barycenter(A, M, reg, weights=None, numItermax=1000, stopThr=1e-4, verbose=F
return geometricBar(weights, UKv)
-def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, stopThr=1e-3, verbose=False, log=False):
+def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
+ stopThr=1e-3, verbose=False, log=False):
"""
Compute the unmixing of an observation with a given dictionary using Wasserstein distance
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index 5c09da2..6371feb 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -107,7 +107,8 @@ def emd(a, b, M, numItermax=100000, log=False):
return G
-def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=100000, log=False, return_matrix=False):
+def emd2(a, b, M, processes=multiprocessing.cpu_count(),
+ numItermax=100000, log=False, return_matrix=False):
"""Solves the Earth Movers distance problem and returns the loss
.. math::
diff --git a/ot/optim.py b/ot/optim.py
index 1d09adc..f31fae2 100644
--- a/ot/optim.py
+++ b/ot/optim.py
@@ -15,7 +15,8 @@ from .bregman import sinkhorn
# The corresponding scipy function does not work for matrices
-def line_search_armijo(f, xk, pk, gfk, old_fval, args=(), c1=1e-4, alpha0=0.99):
+def line_search_armijo(f, xk, pk, gfk, old_fval,
+ args=(), c1=1e-4, alpha0=0.99):
"""
Armijo linesearch function that works with matrices
@@ -71,7 +72,8 @@ def line_search_armijo(f, xk, pk, gfk, old_fval, args=(), c1=1e-4, alpha0=0.99):
return alpha, fc[0], phi1
-def cg(a, b, M, reg, f, df, G0=None, numItermax=200, stopThr=1e-9, verbose=False, log=False):
+def cg(a, b, M, reg, f, df, G0=None, numItermax=200,
+ stopThr=1e-9, verbose=False, log=False):
"""
Solve the general regularized OT problem with conditional gradient
@@ -202,7 +204,8 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, stopThr=1e-9, verbose=False
return G
-def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10, numInnerItermax=200, stopThr=1e-9, verbose=False, log=False):
+def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
+ numInnerItermax=200, stopThr=1e-9, verbose=False, log=False):
"""
Solve the general regularized OT problem with the generalized conditional gradient
diff --git a/ot/utils.py b/ot/utils.py
index 9eab3fc..16862ea 100644
--- a/ot/utils.py
+++ b/ot/utils.py
@@ -316,7 +316,7 @@ def _is_deprecated(func):
closures = []
is_deprecated = ('deprecated' in ''.join([c.cell_contents
for c in closures
- if isinstance(c.cell_contents, str)]))
+ if isinstance(c.cell_contents, str)]))
return is_deprecated
diff --git a/test/test_da.py b/test/test_da.py
index 593dc53..7b63daf 100644
--- a/test/test_da.py
+++ b/test/test_da.py
@@ -444,6 +444,24 @@ def test_mapping_transport_class():
assert len(otda.log_.keys()) != 0
+def test_linear_mapping():
+
+ ns = 150
+ nt = 200
+
+ Xs, ys = get_data_classif('3gauss', ns)
+ Xt, yt = get_data_classif('3gauss2', nt)
+
+ A, b = ot.da.OT_mapping_linear(Xs, Xt)
+
+ Xst = Xs.dot(A) + b
+
+ Ct = np.cov(Xt.T)
+ Cst = np.cov(Xst.T)
+
+ np.testing.assert_allclose(Ct, Cst, rtol=1e-2, atol=1e-2)
+
+
def test_otda():
n_samples = 150 # nb samples