summaryrefslogtreecommitdiff
path: root/ot
diff options
context:
space:
mode:
authorRĂ©mi Flamary <remi.flamary@gmail.com>2017-07-26 15:25:53 +0200
committerGitHub <noreply@github.com>2017-07-26 15:25:53 +0200
commit7638d019b43e52d17600cac653939e7cd807478c (patch)
treea77441ddf844d953a3e797a3fab2a1ee3b85bf34 /ot
parent1cf304cee298e2752ce29c83e5201f593722c3af (diff)
parent838550ead9cc8a66d9b9c1212c5dda2457dc59a5 (diff)
Merge pull request #19 from rflamary/pytest
Pytest with 89% coverage Fixes #19
Diffstat (limited to 'ot')
-rw-r--r--ot/__init__.py6
-rw-r--r--ot/bregman.py25
-rw-r--r--ot/da.py10
-rw-r--r--ot/datasets.py4
-rw-r--r--ot/dr.py4
-rw-r--r--ot/gpu/__init__.py5
-rw-r--r--ot/gpu/bregman.py5
-rw-r--r--ot/gpu/da.py8
-rw-r--r--ot/lp/__init__.py4
-rw-r--r--ot/lp/emd_wrap.pyx9
-rw-r--r--ot/optim.py6
-rw-r--r--ot/plot.py3
-rw-r--r--ot/utils.py5
13 files changed, 78 insertions, 16 deletions
diff --git a/ot/__init__.py b/ot/__init__.py
index a79a5ce..6d4c4c6 100644
--- a/ot/__init__.py
+++ b/ot/__init__.py
@@ -4,6 +4,10 @@
"""
+# Author: Remi Flamary <remi.flamary@unice.fr>
+#
+# License: MIT License
+
# All submodules and packages
from . import lp
@@ -24,6 +28,6 @@ from .utils import dist, unif, tic, toc, toq
__version__ = "0.3.1"
-__all__ = ["emd", "emd2", "sinkhorn","sinkhorn2", "utils", 'datasets',
+__all__ = ["emd", "emd2", "sinkhorn", "sinkhorn2", "utils", 'datasets',
'bregman', 'lp', 'plot', 'tic', 'toc', 'toq',
'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim']
diff --git a/ot/bregman.py b/ot/bregman.py
index fe10880..d63c51d 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -3,6 +3,11 @@
Bregman projections for regularized OT
"""
+# Author: Remi Flamary <remi.flamary@unice.fr>
+# Nicolas Courty <ncourty@irisa.fr>
+#
+# License: MIT License
+
import numpy as np
@@ -103,8 +108,9 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, stopThr=1e-9, ver
stopThr=stopThr, verbose=verbose, log=log, **kwargs)
elif method.lower() == 'sinkhorn_epsilon_scaling':
def sink():
- return sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+ return sinkhorn_epsilon_scaling(
+ a, b, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log, **kwargs)
else:
print('Warning : unknown method using classic Sinkhorn Knopp')
@@ -211,8 +217,9 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, stopThr=1e-9, ve
stopThr=stopThr, verbose=verbose, log=log, **kwargs)
elif method.lower() == 'sinkhorn_epsilon_scaling':
def sink():
- return sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+ return sinkhorn_epsilon_scaling(
+ a, b, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log, **kwargs)
else:
print('Warning : unknown method using classic Sinkhorn Knopp')
@@ -588,7 +595,7 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, wa
cpt = cpt + 1
- #print('err=',err,' cpt=',cpt)
+ # print('err=',err,' cpt=',cpt)
if log:
log['logu'] = alpha / reg + np.log(u)
log['logv'] = beta / reg + np.log(v)
@@ -773,7 +780,7 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInne
loop = False
cpt = cpt + 1
- #print('err=',err,' cpt=',cpt)
+ # print('err=',err,' cpt=',cpt)
if log:
log['alpha'] = alpha
log['beta'] = beta
@@ -960,16 +967,16 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, stopThr=1e-3, verb
"""
- #M = M/np.median(M)
+ # M = M/np.median(M)
K = np.exp(-M / reg)
- #M0 = M0/np.median(M0)
+ # M0 = M0/np.median(M0)
K0 = np.exp(-M0 / reg0)
old = h0
err = 1
cpt = 0
- #log = {'niter':0, 'all_err':[]}
+ # log = {'niter':0, 'all_err':[]}
if log:
log = {'err': []}
diff --git a/ot/da.py b/ot/da.py
index 5039fbd..4f9bce5 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -3,6 +3,12 @@
Domain adaptation with optimal transport
"""
+# Author: Remi Flamary <remi.flamary@unice.fr>
+# Nicolas Courty <ncourty@irisa.fr>
+# Michael Perrot <michael.perrot@univ-st-etienne.fr>
+#
+# License: MIT License
+
import numpy as np
from .bregman import sinkhorn
from .lp import emd
@@ -472,7 +478,7 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian', sigm
Kp[:ns, :ns] = K
# ls regu
- #K0 = K1.T.dot(K1)+eta*I
+ # K0 = K1.T.dot(K1)+eta*I
# Kreg=I
# RKHS regul
@@ -484,7 +490,7 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian', sigm
I = np.eye(ns)
# ls regul
- #K0 = K1.T.dot(K1)+eta*I
+ # K0 = K1.T.dot(K1)+eta*I
# Kreg=I
# proper kernel ridge
diff --git a/ot/datasets.py b/ot/datasets.py
index 4371a23..e4fe118 100644
--- a/ot/datasets.py
+++ b/ot/datasets.py
@@ -2,6 +2,10 @@
Simple example datasets for OT
"""
+# Author: Remi Flamary <remi.flamary@unice.fr>
+#
+# License: MIT License
+
import numpy as np
import scipy as sp
diff --git a/ot/dr.py b/ot/dr.py
index 77cbae2..d30ab30 100644
--- a/ot/dr.py
+++ b/ot/dr.py
@@ -3,6 +3,10 @@
Dimension reduction with optimal transport
"""
+# Author: Remi Flamary <remi.flamary@unice.fr>
+#
+# License: MIT License
+
from scipy import linalg
import autograd.numpy as np
from pymanopt.manifolds import Stiefel
diff --git a/ot/gpu/__init__.py b/ot/gpu/__init__.py
index 40b11c0..c8f9433 100644
--- a/ot/gpu/__init__.py
+++ b/ot/gpu/__init__.py
@@ -4,4 +4,9 @@ from . import bregman
from . import da
from .bregman import sinkhorn
+# Author: Remi Flamary <remi.flamary@unice.fr>
+# Leo Gautheron <https://github.com/aje>
+#
+# License: MIT License
+
__all__ = ["bregman", "da", "sinkhorn"]
diff --git a/ot/gpu/bregman.py b/ot/gpu/bregman.py
index 2302f80..47939c4 100644
--- a/ot/gpu/bregman.py
+++ b/ot/gpu/bregman.py
@@ -3,6 +3,11 @@
Bregman projections for regularized OT with GPU
"""
+# Author: Remi Flamary <remi.flamary@unice.fr>
+# Leo Gautheron <https://github.com/aje>
+#
+# License: MIT License
+
import numpy as np
import cudamat
diff --git a/ot/gpu/da.py b/ot/gpu/da.py
index c66e755..05c580f 100644
--- a/ot/gpu/da.py
+++ b/ot/gpu/da.py
@@ -3,6 +3,14 @@
Domain adaptation with optimal transport with GPU implementation
"""
+# Author: Remi Flamary <remi.flamary@unice.fr>
+# Nicolas Courty <ncourty@irisa.fr>
+# Michael Perrot <michael.perrot@univ-st-etienne.fr>
+# Leo Gautheron <https://github.com/aje>
+#
+# License: MIT License
+
+
import numpy as np
from ..utils import unif
from ..da import OTDA
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index db3da78..6e0bdb8 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -3,6 +3,10 @@
Solvers for the original linear program OT problem
"""
+# Author: Remi Flamary <remi.flamary@unice.fr>
+#
+# License: MIT License
+
import numpy as np
# import compiled emd
from .emd_wrap import emd_c, emd2_c
diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx
index 46794ab..46c96c1 100644
--- a/ot/lp/emd_wrap.pyx
+++ b/ot/lp/emd_wrap.pyx
@@ -1,9 +1,12 @@
# -*- coding: utf-8 -*-
"""
-Created on Thu Sep 11 08:42:08 2014
-
-@author: rflamary
+Cython linker with C solver
"""
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+#
+# License: MIT License
+
import numpy as np
cimport numpy as np
diff --git a/ot/optim.py b/ot/optim.py
index adad95e..1d09adc 100644
--- a/ot/optim.py
+++ b/ot/optim.py
@@ -3,6 +3,10 @@
Optimization algorithms for OT
"""
+# Author: Remi Flamary <remi.flamary@unice.fr>
+#
+# License: MIT License
+
import numpy as np
from scipy.optimize.linesearch import scalar_search_armijo
from .lp import emd
@@ -300,7 +304,7 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10, numInnerItermax=200,
Mi = M + reg2 * df(G)
# solve linear program with Sinkhorn
- #Gc = sinkhorn_stabilized(a,b, Mi, reg1, numItermax = numInnerItermax)
+ # Gc = sinkhorn_stabilized(a,b, Mi, reg1, numItermax = numInnerItermax)
Gc = sinkhorn(a, b, Mi, reg1, numItermax=numInnerItermax)
deltaG = Gc - G
diff --git a/ot/plot.py b/ot/plot.py
index 61afc9f..784a372 100644
--- a/ot/plot.py
+++ b/ot/plot.py
@@ -2,6 +2,9 @@
Functions for plotting OT matrices
"""
+# Author: Remi Flamary <remi.flamary@unice.fr>
+#
+# License: MIT License
import numpy as np
import matplotlib.pylab as pl
diff --git a/ot/utils.py b/ot/utils.py
index 1dee932..2b2f8b3 100644
--- a/ot/utils.py
+++ b/ot/utils.py
@@ -2,6 +2,11 @@
"""
Various function that can be usefull
"""
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+#
+# License: MIT License
+
import multiprocessing
from functools import reduce
import time