diff options
author | RĂ©mi Flamary <remi.flamary@gmail.com> | 2022-12-15 09:28:01 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-12-15 09:28:01 +0100 |
commit | 0411ea22a96f9c22af30156b45c16ef39ffb520d (patch) | |
tree | 7c131ad804d5b16a8c362c2fe296350a770400df /test/test_utils.py | |
parent | 8490196dcc982c492b7565e1ec4de5f75f006acf (diff) |
[MRG] New API for OT solver (with pre-computed ground cost matrix) (#388)
* new API for OT solver
* use itertools for product of parameters
* add tests for result class
* add tests for result class
* add tests for result class last time?
* add sinkhorn
* make partial OT bckend compatible
* add TV as unbalanced flavor
* better tests
* make smoth backend compatible and add l2 tregularizatio to solve
* add reularizedd unbalanced
* add test for more complex attibutes
* add test for more complex attibutes
* add generic unbalaned solver and implement it for ot.solve
* add entropy to possible regularization
* star of documentation for ot.solv
* weird new pep8
* documenttaion for function ot.solve done
* pep8
* Update ot/solvers.py
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
* update release file
* Apply suggestions from code review
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
* add test NotImplemented
* pep8
* pep8gcmp pep8!
* compute kl in backend
* debug tensorflow kl backend
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
Diffstat (limited to 'test/test_utils.py')
-rw-r--r-- | test/test_utils.py | 29 |
1 files changed, 29 insertions, 0 deletions
diff --git a/test/test_utils.py b/test/test_utils.py index 19b6365..666c157 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -301,3 +301,32 @@ def test_BaseEstimator(): cl.set_params(bibi=10) assert cl.first == 'spam again' + + +def test_OTResult(): + + res = ot.utils.OTResult() + + # test print + print(res) + + # tets get citation + print(res.citation) + + lst_attributes = ['a_to_b', + 'b_to_a', + 'lazy_plan', + 'marginal_a', + 'marginal_b', + 'marginals', + 'plan', + 'potential_a', + 'potential_b', + 'potentials', + 'sparse_plan', + 'status', + 'value', + 'value_linear'] + for at in lst_attributes: + with pytest.raises(NotImplementedError): + getattr(res, at) |