summaryrefslogtreecommitdiff
path: root/test/test_optim.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2017-07-24 14:58:15 +0200
committerRémi Flamary <remi.flamary@gmail.com>2017-07-24 14:58:15 +0200
commit7d9c5e7ef81cfb1cd4725058c09a7f683ca03eef (patch)
tree22577f7f53ff8b4904e96c2943846582ff505a84 /test/test_optim.py
parentf8e822c48eff02a3d65fc83d09dc0471bc9555aa (diff)
add test optim
Diffstat (limited to 'test/test_optim.py')
-rw-r--r--test/test_optim.py65
1 files changed, 65 insertions, 0 deletions
diff --git a/test/test_optim.py b/test/test_optim.py
new file mode 100644
index 0000000..43cba7d
--- /dev/null
+++ b/test/test_optim.py
@@ -0,0 +1,65 @@
+
+
+import ot
+import numpy as np
+
+# import pytest
+
+
def test_conditional_gradient():
    """Conditional-gradient solver with a quadratic regularization term.

    Builds two 1D Gaussian histograms and a normalized squared-Euclidean
    cost matrix, solves the regularized OT problem with ``ot.optim.cg``,
    and checks that the returned transport plan satisfies both marginal
    constraints.
    """
    n = 100  # number of bins per histogram

    # bin positions on the real line
    x = np.arange(n, dtype=np.float64)

    # source / target 1D Gaussian histograms (m = mean, s = std)
    a = ot.datasets.get_1D_gauss(n, m=20, s=5)
    b = ot.datasets.get_1D_gauss(n, m=60, s=10)

    # squared-Euclidean loss matrix, normalized so max entry is 1
    M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)))
    M /= M.max()

    def f(G):
        # quadratic regularizer: 0.5 * ||G||_F^2
        return 0.5 * np.sum(G**2)

    def df(G):
        # gradient of the quadratic regularizer
        return G

    reg = 1e-1

    # log=True exercises the solver's logging path even though the log
    # dict itself is not inspected here
    G, log = ot.optim.cg(a, b, M, reg, f, df, verbose=True, log=True)

    # assert_allclose reports the offending values on failure, unlike a
    # bare assert on np.allclose; atol matches the tolerance used by
    # test_generalized_conditional_gradient for consistency
    np.testing.assert_allclose(a, G.sum(1), atol=1e-05)
    np.testing.assert_allclose(b, G.sum(0), atol=1e-05)
+
+
def test_generalized_conditional_gradient():
    """Generalized conditional-gradient solver (entropic + quadratic terms).

    Solves a doubly regularized OT problem between two 1D Gaussian
    histograms with ``ot.optim.gcg`` and verifies that the resulting
    coupling matches both marginals up to 1e-5.
    """
    n_bins = 100  # histogram size

    # positions of the bins
    positions = np.arange(n_bins, dtype=np.float64)

    # source / target 1D Gaussian histograms (m = mean, s = std)
    a = ot.datasets.get_1D_gauss(n_bins, m=20, s=5)
    b = ot.datasets.get_1D_gauss(n_bins, m=60, s=10)

    # squared-Euclidean cost matrix, scaled so its largest entry is 1
    M = ot.dist(positions.reshape((n_bins, 1)), positions.reshape((n_bins, 1)))
    M /= M.max()

    def quad_cost(G):
        # 0.5 * squared Frobenius norm of the transport plan
        return 0.5 * np.sum(G**2)

    def quad_grad(G):
        # gradient of quad_cost
        return G

    reg1, reg2 = 1e-3, 1e-1  # entropic / quadratic regularization weights

    G, log = ot.optim.gcg(a, b, M, reg1, reg2, quad_cost, quad_grad,
                          verbose=True, log=True)

    # both marginals of the coupling must match the inputs
    assert np.allclose(a, G.sum(1), atol=1e-05)
    assert np.allclose(b, G.sum(0), atol=1e-05)