From af99d3ff66062811a81454ea03e6b831a1292ae4 Mon Sep 17 00:00:00 2001 From: RĂ©mi Flamary Date: Thu, 31 May 2018 11:26:32 +0200 Subject: add smooth.py + first tests --- ot/smooth.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) (limited to 'ot') diff --git a/ot/smooth.py b/ot/smooth.py index b8c0a9a..f8bb20a 100644 --- a/ot/smooth.py +++ b/ot/smooth.py @@ -85,10 +85,12 @@ class Regularization(object): """ Compute delta_Omega(X[:, j]) for each X[:, j]. delta_Omega(x) = sup_{y >= 0} y^T x - Omega(y). + Parameters ---------- X: array, shape = len(a) x len(b) Input array. + Returns ------- v: array, len(b) @@ -107,6 +109,7 @@ class Regularization(object): ---------- X: array, shape = len(a) x len(b) Input array. + Returns ------- v: array, len(b) @@ -124,6 +127,7 @@ class Regularization(object): ---------- T: array, shape = len(a) x len(b) Input array. + Returns ------- value: float @@ -232,6 +236,7 @@ def solve_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500): Tolerance parameter. max_iter: int Maximum number of iterations. + Returns ------- alpha: array, shape = len(a) @@ -270,6 +275,7 @@ def solve_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500): def semi_dual_obj_grad(alpha, a, b, C, regul): """ Compute objective value and gradient of semi-dual objective. + Parameters ---------- alpha: array, shape = len(a) @@ -281,6 +287,7 @@ def semi_dual_obj_grad(alpha, a, b, C, regul): Ground cost matrix. regul: Regularization object Should implement a max_Omega(X) method. + Returns ------- obj: float @@ -307,6 +314,7 @@ def semi_dual_obj_grad(alpha, a, b, C, regul): def solve_semi_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500): """ Solve the "smoothed" semi-dual objective. + Parameters ---------- a: array, shape = len(a) @@ -322,6 +330,7 @@ def solve_semi_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500): Tolerance parameter. max_iter: int Maximum number of iterations. + Returns ------- alpha: array, shape = len(a) @@ -344,6 +353,7 @@ def solve_semi_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500): def get_plan_from_dual(alpha, beta, C, regul): """ Retrieve optimal transportation plan from optimal dual potentials. + Parameters ---------- alpha: array, shape = len(a) @@ -353,6 +363,7 @@ def get_plan_from_dual(alpha, beta, C, regul): Ground cost matrix. regul: Regularization object Should implement a delta_Omega(X) method. + Returns ------- T: array, shape = len(a) x len(b) @@ -365,6 +376,7 @@ def get_plan_from_dual(alpha, beta, C, regul): def get_plan_from_semi_dual(alpha, b, C, regul): """ Retrieve optimal transportation plan from optimal semi-dual potentials. + Parameters ---------- alpha: array, shape = len(a) @@ -375,6 +387,7 @@ def get_plan_from_semi_dual(alpha, b, C, regul): Ground cost matrix. regul: Regularization object Should implement a delta_Omega(X) method. + Returns ------- T: array, shape = len(a) x len(b) -- cgit v1.2.3