summaryrefslogtreecommitdiff
path: root/ot/smooth.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2018-05-31 11:26:32 +0200
committerRémi Flamary <remi.flamary@gmail.com>2018-05-31 11:26:32 +0200
commitaf99d3ff66062811a81454ea03e6b831a1292ae4 (patch)
tree6ade5d0ec2360bb3630635f3fdfad93b758a6323 /ot/smooth.py
parent8d288976d398749c5261daca22ee4d772bd5b489 (diff)
add smooth.py + first tests
Diffstat (limited to 'ot/smooth.py')
-rw-r--r--ot/smooth.py13
1 files changed, 13 insertions, 0 deletions
diff --git a/ot/smooth.py b/ot/smooth.py
index b8c0a9a..f8bb20a 100644
--- a/ot/smooth.py
+++ b/ot/smooth.py
@@ -85,10 +85,12 @@ class Regularization(object):
"""
Compute delta_Omega(X[:, j]) for each X[:, j].
delta_Omega(x) = sup_{y >= 0} y^T x - Omega(y).
+
Parameters
----------
X: array, shape = len(a) x len(b)
Input array.
+
Returns
-------
v: array, len(b)
@@ -107,6 +109,7 @@ class Regularization(object):
----------
X: array, shape = len(a) x len(b)
Input array.
+
Returns
-------
v: array, len(b)
@@ -124,6 +127,7 @@ class Regularization(object):
----------
T: array, shape = len(a) x len(b)
Input array.
+
Returns
-------
value: float
@@ -232,6 +236,7 @@ def solve_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500):
Tolerance parameter.
max_iter: int
Maximum number of iterations.
+
Returns
-------
alpha: array, shape = len(a)
@@ -270,6 +275,7 @@ def solve_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500):
def semi_dual_obj_grad(alpha, a, b, C, regul):
"""
Compute objective value and gradient of semi-dual objective.
+
Parameters
----------
alpha: array, shape = len(a)
@@ -281,6 +287,7 @@ def semi_dual_obj_grad(alpha, a, b, C, regul):
Ground cost matrix.
regul: Regularization object
Should implement a max_Omega(X) method.
+
Returns
-------
obj: float
@@ -307,6 +314,7 @@ def semi_dual_obj_grad(alpha, a, b, C, regul):
def solve_semi_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500):
"""
Solve the "smoothed" semi-dual objective.
+
Parameters
----------
a: array, shape = len(a)
@@ -322,6 +330,7 @@ def solve_semi_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500):
Tolerance parameter.
max_iter: int
Maximum number of iterations.
+
Returns
-------
alpha: array, shape = len(a)
@@ -344,6 +353,7 @@ def solve_semi_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500):
def get_plan_from_dual(alpha, beta, C, regul):
"""
Retrieve optimal transportation plan from optimal dual potentials.
+
Parameters
----------
alpha: array, shape = len(a)
@@ -353,6 +363,7 @@ def get_plan_from_dual(alpha, beta, C, regul):
Ground cost matrix.
regul: Regularization object
Should implement a delta_Omega(X) method.
+
Returns
-------
T: array, shape = len(a) x len(b)
@@ -365,6 +376,7 @@ def get_plan_from_dual(alpha, beta, C, regul):
def get_plan_from_semi_dual(alpha, b, C, regul):
"""
Retrieve optimal transportation plan from optimal semi-dual potentials.
+
Parameters
----------
alpha: array, shape = len(a)
@@ -375,6 +387,7 @@ def get_plan_from_semi_dual(alpha, b, C, regul):
Ground cost matrix.
regul: Regularization object
Should implement a delta_Omega(X) method.
+
Returns
-------
T: array, shape = len(a) x len(b)