summaryrefslogtreecommitdiff
path: root/test/test_utils.py
diff options
context:
space:
mode:
authorRĂ©mi Flamary <remi.flamary@gmail.com>2022-12-15 09:28:01 +0100
committerGitHub <noreply@github.com>2022-12-15 09:28:01 +0100
commit0411ea22a96f9c22af30156b45c16ef39ffb520d (patch)
tree7c131ad804d5b16a8c362c2fe296350a770400df /test/test_utils.py
parent8490196dcc982c492b7565e1ec4de5f75f006acf (diff)
[MRG] New API for OT solver (with pre-computed ground cost matrix) (#388)
* new API for OT solver * use itertools for product of parameters * add tests for result class * add tests for result class * add tests for result class last time? * add sinkhorn * make partial OT bckend compatible * add TV as unbalanced flavor * better tests * make smoth backend compatible and add l2 tregularizatio to solve * add reularizedd unbalanced * add test for more complex attibutes * add test for more complex attibutes * add generic unbalaned solver and implement it for ot.solve * add entropy to possible regularization * star of documentation for ot.solv * weird new pep8 * documenttaion for function ot.solve done * pep8 * Update ot/solvers.py Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * update release file * Apply suggestions from code review Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * add test NotImplemented * pep8 * pep8gcmp pep8! * compute kl in backend * debug tensorflow kl backend Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
Diffstat (limited to 'test/test_utils.py')
-rw-r--r--test/test_utils.py29
1 files changed, 29 insertions, 0 deletions
diff --git a/test/test_utils.py b/test/test_utils.py
index 19b6365..666c157 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -301,3 +301,32 @@ def test_BaseEstimator():
cl.set_params(bibi=10)
assert cl.first == 'spam again'
+
+
+def test_OTResult():
+
+ res = ot.utils.OTResult()
+
+ # test print
+ print(res)
+
+ # tets get citation
+ print(res.citation)
+
+ lst_attributes = ['a_to_b',
+ 'b_to_a',
+ 'lazy_plan',
+ 'marginal_a',
+ 'marginal_b',
+ 'marginals',
+ 'plan',
+ 'potential_a',
+ 'potential_b',
+ 'potentials',
+ 'sparse_plan',
+ 'status',
+ 'value',
+ 'value_linear']
+ for at in lst_attributes:
+ with pytest.raises(NotImplementedError):
+ getattr(res, at)