summary refs log tree commit diff
path: root/ot/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'ot/utils.py')
-rw-r--r-- ot/utils.py 238
1 files changed, 235 insertions, 3 deletions
diff --git a/ot/utils.py b/ot/utils.py
index a23ce7e..3423a7e 100644
--- a/ot/utils.py
+++ b/ot/utils.py
@@ -15,7 +15,7 @@ from scipy.spatial.distance import cdist
import sys
import warnings
from inspect import signature
-from .backend import get_backend, Backend
+from .backend import get_backend, Backend, NumpyBackend
__time_tic_toc = time.time()
@@ -232,9 +232,11 @@ def dist(x1, x2=None, metric='sqeuclidean', p=2, w=None):
if not get_backend(x1, x2).__name__ == 'numpy':
raise NotImplementedError()
else:
- if metric.endswith("minkowski"):
+ if isinstance(metric, str) and metric.endswith("minkowski"):
return cdist(x1, x2, metric=metric, p=p, w=w)
- return cdist(x1, x2, metric=metric, w=w)
+ if w is not None:
+ return cdist(x1, x2, metric=metric, w=w)
+ return cdist(x1, x2, metric=metric)
def dist0(n, method='lin_square'):
@@ -373,6 +375,36 @@ def check_random_state(seed):
' instance'.format(seed))
+def get_coordinate_circle(x):
+    r"""Convert points of the unit circle into a position expressed in turns.
+
+    For :math:`x\in S^1 \subset \mathbb{R}^2`, the returned coordinate (in
+    [0,1[) is
+
+    .. math::
+        u = \frac{\pi + \mathrm{atan2}(-x_2,-x_1)}{2\pi}
+
+    Parameters
+    ----------
+    x: ndarray, shape (n, 2)
+        Samples on the circle with ambient coordinates
+
+    Returns
+    -------
+    x_t: ndarray, shape (n,)
+        Coordinates on [0,1[
+
+    Examples
+    --------
+    >>> u = np.array([[0.2,0.5,0.8]]) * (2 * np.pi)
+    >>> x1, y1 = np.cos(u), np.sin(u)
+    >>> x = np.concatenate([x1, y1]).T
+    >>> get_coordinate_circle(x)
+    array([0.2, 0.5, 0.8])
+    """
+    nx = get_backend(x)
+    first, second = x[:, 0], x[:, 1]
+    # Negating both coordinates rotates the point by half a turn, so adding
+    # pi afterwards shifts atan2's [-pi, pi] output onto [0, 2*pi].
+    angle = nx.atan2(-second, -first)
+    return (angle + np.pi) / (2 * np.pi)
+
+
class deprecated(object):
r"""Decorator to mark a function or class as deprecated.
@@ -609,3 +641,203 @@ class UndefinedParameter(Exception):
"""
pass
+
+
+class OTResult:
+    """Container for the quantities returned by an optimal transport solver.
+
+    Every field is optional. A property returns the stored value when the
+    solver supplied it and raises ``NotImplementedError`` otherwise.
+    Solver-specific results are meant to inherit from this class (e.g. a
+    ``SinkhornOTResult``) and implement the properties that are relevant
+    for them (e.g. ``plan`` and ``lazy_plan`` but not ``sparse_plan``).
+
+    Parameters
+    ----------
+    potentials : indexable pair, optional
+        Dual potentials; item 0 is exposed as ``potential_a`` and item 1 as
+        ``potential_b``.
+    value : optional
+        Full transport cost, including possible regularization terms.
+    value_linear : optional
+        Linear part of the transport cost.
+    plan : optional
+        Dense transport plan. Must expose ``.shape`` (used by ``__repr__``).
+    log : optional
+        Dictionary containing potential information about the solver.
+        It is stored on the instance but no property exposes it.
+    backend : optional
+        Backend providing ``sum`` and ``tocsr``. Defaults to a new
+        ``NumpyBackend()``.
+    sparse_plan : optional
+        Transport plan encoded as a sparse array.
+    lazy_plan : optional
+        Transport plan encoded as a symbolic (lazy) tensor.
+    status : optional
+        Optimization status of the solver.
+    """
+
+    def __init__(self, potentials=None, value=None, value_linear=None,
+                 plan=None, log=None, backend=None, sparse_plan=None,
+                 lazy_plan=None, status=None):
+        self._potentials = potentials
+        self._value = value
+        self._value_linear = value_linear
+        self._plan = plan
+        self._log = log
+        self._sparse_plan = sparse_plan
+        self._lazy_plan = lazy_plan
+        # Fall back to NumPy so the marginal and sparse conversions still
+        # work when the solver did not say which backend it used.
+        self._backend = backend if backend is not None else NumpyBackend()
+        self._status = status
+
+    def __repr__(self):
+        """Return a compact summary listing only the fields that are set."""
+        fields = []
+        if self._value is not None:
+            fields.append('value={}'.format(self._value))
+        if self._value_linear is not None:
+            fields.append('value_linear={}'.format(self._value_linear))
+        if self._plan is not None:
+            # Only the type and shape are shown: the plan may be very large.
+            fields.append('plan={}(shape={})'.format(
+                self._plan.__class__.__name__, self._plan.shape))
+        return 'OTResult(' + ','.join(fields) + ')'
+
+    # Dual potentials --------------------------------------------
+
+    @property
+    def potentials(self):
+        """Dual potentials, i.e. Lagrange multipliers for the marginal constraints.
+
+        This pair of arrays has the same shape, numerical type
+        and properties as the input weights "a" and "b".
+        """
+        if self._potentials is not None:
+            return self._potentials
+        else:
+            raise NotImplementedError()
+
+    @property
+    def potential_a(self):
+        """First dual potential, associated to the "source" measure "a"."""
+        if self._potentials is not None:
+            return self._potentials[0]
+        else:
+            raise NotImplementedError()
+
+    @property
+    def potential_b(self):
+        """Second dual potential, associated to the "target" measure "b"."""
+        if self._potentials is not None:
+            return self._potentials[1]
+        else:
+            raise NotImplementedError()
+
+    # Transport plan -------------------------------------------
+    @property
+    def plan(self):
+        """Transport plan, encoded as a dense array."""
+        # N.B.: We may catch out-of-memory errors and suggest
+        # the use of lazy_plan or sparse_plan when appropriate.
+
+        if self._plan is not None:
+            return self._plan
+        else:
+            raise NotImplementedError()
+
+    @property
+    def sparse_plan(self):
+        """Transport plan, encoded as a sparse array.
+
+        Falls back to converting the dense plan with the backend's
+        ``tocsr`` when no sparse plan was stored.
+        """
+        if self._sparse_plan is not None:
+            return self._sparse_plan
+        elif self._plan is not None:
+            return self._backend.tocsr(self._plan)
+        else:
+            raise NotImplementedError()
+
+    @property
+    def lazy_plan(self):
+        """Transport plan, encoded as a symbolic KeOps LazyTensor."""
+        # Return the plan handed to the constructor; previously it was
+        # stored but could never be read back.
+        if self._lazy_plan is not None:
+            return self._lazy_plan
+        else:
+            raise NotImplementedError()
+
+    # Loss values --------------------------------
+
+    @property
+    def value(self):
+        """Full transport cost, including possible regularization terms."""
+        if self._value is not None:
+            return self._value
+        else:
+            raise NotImplementedError()
+
+    @property
+    def value_linear(self):
+        """The "minimal" transport cost, i.e. the product between the transport plan and the cost."""
+        if self._value_linear is not None:
+            return self._value_linear
+        else:
+            raise NotImplementedError()
+
+    # Marginal constraints -------------------------
+    @property
+    def marginals(self):
+        """Marginals of the transport plan: should be very close to "a" and "b"
+        for balanced OT."""
+        if self._plan is not None:
+            return self.marginal_a, self.marginal_b
+        else:
+            raise NotImplementedError()
+
+    @property
+    def marginal_a(self):
+        """First marginal of the transport plan, with the same shape as "a"."""
+        if self._plan is not None:
+            return self._backend.sum(self._plan, 1)
+        else:
+            raise NotImplementedError()
+
+    @property
+    def marginal_b(self):
+        """Second marginal of the transport plan, with the same shape as "b"."""
+        if self._plan is not None:
+            return self._backend.sum(self._plan, 0)
+        else:
+            raise NotImplementedError()
+
+    @property
+    def status(self):
+        """Optimization status of the solver."""
+        if self._status is not None:
+            return self._status
+        else:
+            raise NotImplementedError()
+
+    # Barycentric mappings -------------------------
+    # Return the displacement vectors as an array
+    # that has the same shape as "xa"/"xb" (for samples)
+    # or "a"/"b" * D (for images)?
+
+    @property
+    def a_to_b(self):
+        """Displacement vectors from the first to the second measure."""
+        raise NotImplementedError()
+
+    @property
+    def b_to_a(self):
+        """Displacement vectors from the second to the first measure."""
+        raise NotImplementedError()
+
+    # # Wasserstein barycenters ----------------------
+    # @property
+    # def masses(self):
+    #     """Masses for the Wasserstein barycenter."""
+    #     raise NotImplementedError()
+
+    # @property
+    # def samples(self):
+    #     """Sample locations for the Wasserstein barycenter."""
+    #     raise NotImplementedError()
+
+    # Miscellaneous --------------------------------
+
+    @property
+    def citation(self):
+        """Appropriate citation(s) for this result, in plain text and BibTex formats."""
+
+        # The string below refers to the POT library:
+        # successor methods may concatenate the relevant references
+        # to the original definitions, solvers and underlying numerical backends.
+        return """POT library:
+
+        POT Python Optimal Transport library, Journal of Machine Learning Research, 22(78):1−8, 2021.
+        Website: https://pythonot.github.io/
+        Rémi Flamary, Nicolas Courty, Alexandre Gramfort, Mokhtar Z. Alaya, Aurélie Boisbunon, Stanislas Chambon, Laetitia Chapel, Adrien Corenflos, Kilian Fatras, Nemo Fournier, Léo Gautheron, Nathalie T.H. Gayraud, Hicham Janati, Alain Rakotomamonjy, Ievgen Redko, Antoine Rolet, Antony Schutz, Vivien Seguy, Danica J. Sutherland, Romain Tavenard, Alexander Tong, Titouan Vayer;
+
+        @article{flamary2021pot,
+          author  = {R{\'e}mi Flamary and Nicolas Courty and Alexandre Gramfort and Mokhtar Z. Alaya and Aur{\'e}lie Boisbunon and Stanislas Chambon and Laetitia Chapel and Adrien Corenflos and Kilian Fatras and Nemo Fournier and L{\'e}o Gautheron and Nathalie T.H. Gayraud and Hicham Janati and Alain Rakotomamonjy and Ievgen Redko and Antoine Rolet and Antony Schutz and Vivien Seguy and Danica J. Sutherland and Romain Tavenard and Alexander Tong and Titouan Vayer},
+          title   = {{POT}: {Python} {Optimal} {Transport}},
+          journal = {Journal of Machine Learning Research},
+          year    = {2021},
+          volume  = {22},
+          number  = {78},
+          pages   = {1-8},
+          url     = {http://jmlr.org/papers/v22/20-451.html}
+        }
+        """