summaryrefslogtreecommitdiff
path: root/ot/optim.py
diff options
context:
space:
mode:
authortvayer <titouan.vayer@gmail.com>2019-05-29 17:05:38 +0200
committertvayer <titouan.vayer@gmail.com>2019-05-29 17:05:48 +0200
commite1bd94bb7e85a0d2fd0fcd7642b06da12c1db6db (patch)
tree1e85920b878ab715d211db56f99e25bfa2482fd3 /ot/optim.py
parentd4320382fa8873d15dcaec7adca3a4723c142515 (diff)
code review1
Diffstat (limited to 'ot/optim.py')
-rw-r--r--ot/optim.py31
1 file changed, 16 insertions(+), 15 deletions(-)
diff --git a/ot/optim.py b/ot/optim.py
index 7d103e2..4d428d9 100644
--- a/ot/optim.py
+++ b/ot/optim.py
@@ -5,6 +5,7 @@ Optimization algorithms for OT
# Author: Remi Flamary <remi.flamary@unice.fr>
# Titouan Vayer <titouan.vayer@irisa.fr>
+#
# License: MIT License
import numpy as np
@@ -88,20 +89,20 @@ def do_linesearch(cost, G, deltaG, Mi, f_val,
Cost matrix of the linearized transport problem. Corresponds to the gradient of the cost
f_val : float
Value of the cost at G
- armijo : bool, optionnal
+ armijo : bool, optional
If True the steps of the line-search is found via an armijo research. Else closed form is used.
If there is convergence issues use False.
- C1 : ndarray (ns,ns), optionnal
+ C1 : ndarray (ns,ns), optional
Structure matrix in the source domain. Only used when armijo=False
- C2 : ndarray (nt,nt), optionnal
+ C2 : ndarray (nt,nt), optional
Structure matrix in the target domain. Only used when armijo=False
- reg : float, optionnal
+ reg : float, optional
Regularization parameter. Only used when armijo=False
Gc : ndarray (ns,nt)
Optimal map found by linearization in the FW algorithm. Only used when armijo=False
constC : ndarray (ns,nt)
Constant for the gromov cost. See [24]. Only used when armijo=False
- M : ndarray (ns,nt), optionnal
+ M : ndarray (ns,nt), optional
Cost matrix between the features. Only used when armijo=False
Returns
-------
@@ -223,9 +224,9 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200,
it = 0
if verbose:
- print('{:5s}|{:12s}|{:8s}'.format(
- 'It.', 'Loss', 'Delta loss') + '\n' + '-' * 32)
- print('{:5d}|{:8e}|{:8e}'.format(it, f_val, 0))
+ print('{:5s}|{:12s}|{:8s}|{:8s}'.format(
+ 'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48)
+ print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, 0, 0))
while loop:
@@ -261,8 +262,8 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200,
if verbose:
if it % 20 == 0:
- print('{:5s}|{:12s}|{:8s}'.format(
- 'It.', 'Loss', 'Relative variation loss', 'Absolute variation loss') + '\n' + '-' * 32)
+ print('{:5s}|{:12s}|{:8s}|{:8s}'.format(
+ 'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48)
print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, relative_delta_fval, abs_delta_fval))
if log:
@@ -363,9 +364,9 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
it = 0
if verbose:
- print('{:5s}|{:12s}|{:8s}'.format(
- 'It.', 'Loss', 'Delta loss') + '\n' + '-' * 32)
- print('{:5d}|{:8e}|{:8e}'.format(it, f_val, 0))
+ print('{:5s}|{:12s}|{:8s}|{:8s}'.format(
+ 'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48)
+ print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, 0, 0))
while loop:
@@ -402,8 +403,8 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
if verbose:
if it % 20 == 0:
- print('{:5s}|{:12s}|{:8s}'.format(
- 'It.', 'Loss', 'Relative variation loss', 'Absolute variation loss') + '\n' + '-' * 32)
+ print('{:5s}|{:12s}|{:8s}|{:8s}'.format(
+ 'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48)
print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, relative_delta_fval, abs_delta_fval))
if log: