diff options
author | RĂ©mi Flamary <remi.flamary@gmail.com> | 2017-07-26 15:25:53 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2017-07-26 15:25:53 +0200 |
commit | 7638d019b43e52d17600cac653939e7cd807478c (patch) | |
tree | a77441ddf844d953a3e797a3fab2a1ee3b85bf34 /ot | |
parent | 1cf304cee298e2752ce29c83e5201f593722c3af (diff) | |
parent | 838550ead9cc8a66d9b9c1212c5dda2457dc59a5 (diff) |
Merge pull request #19 from rflamary/pytest
Pytest with 89% coverage
Fixes #19
Diffstat (limited to 'ot')
-rw-r--r-- | ot/__init__.py | 6 | ||||
-rw-r--r-- | ot/bregman.py | 25 | ||||
-rw-r--r-- | ot/da.py | 10 | ||||
-rw-r--r-- | ot/datasets.py | 4 | ||||
-rw-r--r-- | ot/dr.py | 4 | ||||
-rw-r--r-- | ot/gpu/__init__.py | 5 | ||||
-rw-r--r-- | ot/gpu/bregman.py | 5 | ||||
-rw-r--r-- | ot/gpu/da.py | 8 | ||||
-rw-r--r-- | ot/lp/__init__.py | 4 | ||||
-rw-r--r-- | ot/lp/emd_wrap.pyx | 9 | ||||
-rw-r--r-- | ot/optim.py | 6 | ||||
-rw-r--r-- | ot/plot.py | 3 | ||||
-rw-r--r-- | ot/utils.py | 5 |
13 files changed, 78 insertions, 16 deletions
diff --git a/ot/__init__.py b/ot/__init__.py index a79a5ce..6d4c4c6 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -4,6 +4,10 @@ """ +# Author: Remi Flamary <remi.flamary@unice.fr> +# +# License: MIT License + # All submodules and packages from . import lp @@ -24,6 +28,6 @@ from .utils import dist, unif, tic, toc, toq __version__ = "0.3.1" -__all__ = ["emd", "emd2", "sinkhorn","sinkhorn2", "utils", 'datasets', +__all__ = ["emd", "emd2", "sinkhorn", "sinkhorn2", "utils", 'datasets', 'bregman', 'lp', 'plot', 'tic', 'toc', 'toq', 'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim'] diff --git a/ot/bregman.py b/ot/bregman.py index fe10880..d63c51d 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -3,6 +3,11 @@ Bregman projections for regularized OT """ +# Author: Remi Flamary <remi.flamary@unice.fr> +# Nicolas Courty <ncourty@irisa.fr> +# +# License: MIT License + import numpy as np @@ -103,8 +108,9 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, stopThr=1e-9, ver stopThr=stopThr, verbose=verbose, log=log, **kwargs) elif method.lower() == 'sinkhorn_epsilon_scaling': def sink(): - return sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, log=log, **kwargs) + return sinkhorn_epsilon_scaling( + a, b, M, reg, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, log=log, **kwargs) else: print('Warning : unknown method using classic Sinkhorn Knopp') @@ -211,8 +217,9 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, stopThr=1e-9, ve stopThr=stopThr, verbose=verbose, log=log, **kwargs) elif method.lower() == 'sinkhorn_epsilon_scaling': def sink(): - return sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, log=log, **kwargs) + return sinkhorn_epsilon_scaling( + a, b, M, reg, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, log=log, **kwargs) else: print('Warning : unknown method using classic Sinkhorn Knopp') @@ -588,7 +595,7 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, wa cpt = cpt + 1 - #print('err=',err,' cpt=',cpt) + # print('err=',err,' cpt=',cpt) if log: log['logu'] = alpha / reg + np.log(u) log['logv'] = beta / reg + np.log(v) @@ -773,7 +780,7 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInne loop = False cpt = cpt + 1 - #print('err=',err,' cpt=',cpt) + # print('err=',err,' cpt=',cpt) if log: log['alpha'] = alpha log['beta'] = beta @@ -960,16 +967,16 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, stopThr=1e-3, verb """ - #M = M/np.median(M) + # M = M/np.median(M) K = np.exp(-M / reg) - #M0 = M0/np.median(M0) + # M0 = M0/np.median(M0) K0 = np.exp(-M0 / reg0) old = h0 err = 1 cpt = 0 - #log = {'niter':0, 'all_err':[]} + # log = {'niter':0, 'all_err':[]} if log: log = {'err': []} @@ -3,6 +3,12 @@ Domain adaptation with optimal transport """ +# Author: Remi Flamary <remi.flamary@unice.fr> +# Nicolas Courty <ncourty@irisa.fr> +# Michael Perrot <michael.perrot@univ-st-etienne.fr> +# +# License: MIT License + import numpy as np from .bregman import sinkhorn from .lp import emd @@ -472,7 +478,7 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian', sigm Kp[:ns, :ns] = K # ls regu - #K0 = K1.T.dot(K1)+eta*I + # K0 = K1.T.dot(K1)+eta*I # Kreg=I # RKHS regul @@ -484,7 +490,7 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian', sigm I = np.eye(ns) # ls regul - #K0 = K1.T.dot(K1)+eta*I + # K0 = K1.T.dot(K1)+eta*I # Kreg=I # proper kernel ridge diff --git a/ot/datasets.py b/ot/datasets.py index 4371a23..e4fe118 100644 --- a/ot/datasets.py +++ b/ot/datasets.py @@ -2,6 +2,10 @@ Simple example datasets for OT """ +# Author: Remi Flamary <remi.flamary@unice.fr> +# +# License: MIT License + import numpy as np import scipy as sp @@ -3,6 +3,10 @@ Dimension reduction with optimal transport """ +# Author: Remi Flamary <remi.flamary@unice.fr> +# +# License: MIT License + from scipy import linalg import autograd.numpy as np from pymanopt.manifolds import Stiefel diff --git a/ot/gpu/__init__.py b/ot/gpu/__init__.py index 40b11c0..c8f9433 100644 --- a/ot/gpu/__init__.py +++ b/ot/gpu/__init__.py @@ -4,4 +4,9 @@ from . import bregman from . import da from .bregman import sinkhorn +# Author: Remi Flamary <remi.flamary@unice.fr> +# Leo Gautheron <https://github.com/aje> +# +# License: MIT License + __all__ = ["bregman", "da", "sinkhorn"] diff --git a/ot/gpu/bregman.py b/ot/gpu/bregman.py index 2302f80..47939c4 100644 --- a/ot/gpu/bregman.py +++ b/ot/gpu/bregman.py @@ -3,6 +3,11 @@ Bregman projections for regularized OT with GPU """ +# Author: Remi Flamary <remi.flamary@unice.fr> +# Leo Gautheron <https://github.com/aje> +# +# License: MIT License + import numpy as np import cudamat diff --git a/ot/gpu/da.py b/ot/gpu/da.py index c66e755..05c580f 100644 --- a/ot/gpu/da.py +++ b/ot/gpu/da.py @@ -3,6 +3,14 @@ Domain adaptation with optimal transport with GPU implementation """ +# Author: Remi Flamary <remi.flamary@unice.fr> +# Nicolas Courty <ncourty@irisa.fr> +# Michael Perrot <michael.perrot@univ-st-etienne.fr> +# Leo Gautheron <https://github.com/aje> +# +# License: MIT License + + import numpy as np from ..utils import unif from ..da import OTDA diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index db3da78..6e0bdb8 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -3,6 +3,10 @@ Solvers for the original linear program OT problem """ +# Author: Remi Flamary <remi.flamary@unice.fr> +# +# License: MIT License + import numpy as np # import compiled emd from .emd_wrap import emd_c, emd2_c diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx index 46794ab..46c96c1 100644 --- a/ot/lp/emd_wrap.pyx +++ b/ot/lp/emd_wrap.pyx @@ -1,9 +1,12 @@ # -*- coding: utf-8 -*- """ -Created on Thu Sep 11 08:42:08 2014 - -@author: rflamary +Cython linker with C solver """ + +# Author: Remi Flamary <remi.flamary@unice.fr> +# +# License: MIT License + import numpy as np cimport numpy as np diff --git a/ot/optim.py b/ot/optim.py index adad95e..1d09adc 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -3,6 +3,10 @@ Optimization algorithms for OT """ +# Author: Remi Flamary <remi.flamary@unice.fr> +# +# License: MIT License + import numpy as np from scipy.optimize.linesearch import scalar_search_armijo from .lp import emd @@ -300,7 +304,7 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10, numInnerItermax=200, Mi = M + reg2 * df(G) # solve linear program with Sinkhorn - #Gc = sinkhorn_stabilized(a,b, Mi, reg1, numItermax = numInnerItermax) + # Gc = sinkhorn_stabilized(a,b, Mi, reg1, numItermax = numInnerItermax) Gc = sinkhorn(a, b, Mi, reg1, numItermax=numInnerItermax) deltaG = Gc - G @@ -2,6 +2,9 @@ Functions for plotting OT matrices """ +# Author: Remi Flamary <remi.flamary@unice.fr> +# +# License: MIT License import numpy as np import matplotlib.pylab as pl diff --git a/ot/utils.py b/ot/utils.py index 1dee932..2b2f8b3 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -2,6 +2,11 @@ """ Various function that can be usefull """ + +# Author: Remi Flamary <remi.flamary@unice.fr> +# +# License: MIT License + import multiprocessing from functools import reduce import time |