summaryrefslogtreecommitdiff
path: root/ot/lp/__init__.py
diff options
context:
space:
mode:
Diffstat (limited to 'ot/lp/__init__.py')
-rw-r--r--ot/lp/__init__.py83
1 files changed, 65 insertions, 18 deletions
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index d9b6fa9..abf7fe0 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -225,6 +225,13 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1):
from all compatible backends. But the algorithm uses the C++ CPU backend
which can lead to copy overhead on GPU arrays.
+ .. note:: This function will cast the computed transport plan to the data type
+ of the provided input with the following priority: :math:`\mathbf{a}`,
+ then :math:`\mathbf{b}`, then :math:`\mathbf{M}` if marginals are not provided.
+ Casting to an integer tensor might result in a loss of precision.
+ If this behaviour is unwanted, please make sure to provide a
+ floating point input.
+
Uses the algorithm proposed in :ref:`[1] <references-emd>`.
Parameters
@@ -290,12 +297,16 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1):
a, b, M = list_to_array(a, b, M)
a0, b0, M0 = a, b, M
+ if len(a0) != 0:
+ type_as = a0
+ elif len(b0) != 0:
+ type_as = b0
+ else:
+ type_as = M0
nx = get_backend(M0, a0, b0)
# convert to numpy
- M = nx.to_numpy(M)
- a = nx.to_numpy(a)
- b = nx.to_numpy(b)
+ M, a, b = nx.to_numpy(M, a, b)
# ensure float64
a = np.asarray(a, dtype=np.float64)
@@ -330,15 +341,23 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1):
u, v = estimate_dual_null_weights(u, v, a, b, M)
result_code_string = check_result(result_code)
+ if not nx.is_floating_point(type_as):
+ warnings.warn(
+ "Input histogram consists of integer. The transport plan will be "
+ "casted accordingly, possibly resulting in a loss of precision. "
+ "If this behaviour is unwanted, please make sure your input "
+ "histogram consists of floating point elements.",
+ stacklevel=2
+ )
if log:
log = {}
log['cost'] = cost
- log['u'] = nx.from_numpy(u, type_as=a0)
- log['v'] = nx.from_numpy(v, type_as=b0)
+ log['u'] = nx.from_numpy(u, type_as=type_as)
+ log['v'] = nx.from_numpy(v, type_as=type_as)
log['warning'] = result_code_string
log['result_code'] = result_code
- return nx.from_numpy(G, type_as=M0), log
- return nx.from_numpy(G, type_as=M0)
+ return nx.from_numpy(G, type_as=type_as), log
+ return nx.from_numpy(G, type_as=type_as)
def emd2(a, b, M, processes=1,
@@ -364,6 +383,14 @@ def emd2(a, b, M, processes=1,
from all compatible backends. But the algorithm uses the C++ CPU backend
which can lead to copy overhead on GPU arrays.
+ .. note:: This function will cast the computed transport plan and
+ transportation loss to the data type of the provided input with the
+ following priority: :math:`\mathbf{a}`, then :math:`\mathbf{b}`,
+ then :math:`\mathbf{M}` if marginals are not provided.
+ Casting to an integer tensor might result in a loss of precision.
+ If this behaviour is unwanted, please make sure to provide a
+ floating point input.
+
Uses the algorithm proposed in :ref:`[1] <references-emd2>`.
Parameters
@@ -432,12 +459,16 @@ def emd2(a, b, M, processes=1,
a, b, M = list_to_array(a, b, M)
a0, b0, M0 = a, b, M
+ if len(a0) != 0:
+ type_as = a0
+ elif len(b0) != 0:
+ type_as = b0
+ else:
+ type_as = M0
nx = get_backend(M0, a0, b0)
# convert to numpy
- M = nx.to_numpy(M)
- a = nx.to_numpy(a)
- b = nx.to_numpy(b)
+ M, a, b = nx.to_numpy(M, a, b)
a = np.asarray(a, dtype=np.float64)
b = np.asarray(b, dtype=np.float64)
@@ -470,14 +501,22 @@ def emd2(a, b, M, processes=1,
result_code_string = check_result(result_code)
log = {}
- G = nx.from_numpy(G, type_as=M0)
+ if not nx.is_floating_point(type_as):
+ warnings.warn(
+ "Input histogram consists of integer. The transport plan will be "
+ "casted accordingly, possibly resulting in a loss of precision. "
+ "If this behaviour is unwanted, please make sure your input "
+ "histogram consists of floating point elements.",
+ stacklevel=2
+ )
+ G = nx.from_numpy(G, type_as=type_as)
if return_matrix:
log['G'] = G
- log['u'] = nx.from_numpy(u, type_as=a0)
- log['v'] = nx.from_numpy(v, type_as=b0)
+ log['u'] = nx.from_numpy(u, type_as=type_as)
+ log['v'] = nx.from_numpy(v, type_as=type_as)
log['warning'] = result_code_string
log['result_code'] = result_code
- cost = nx.set_gradients(nx.from_numpy(cost, type_as=M0),
+ cost = nx.set_gradients(nx.from_numpy(cost, type_as=type_as),
(a0, b0, M0), (log['u'], log['v'], G))
return [cost, log]
else:
@@ -491,10 +530,18 @@ def emd2(a, b, M, processes=1,
if np.any(~asel) or np.any(~bsel):
u, v = estimate_dual_null_weights(u, v, a, b, M)
- G = nx.from_numpy(G, type_as=M0)
- cost = nx.set_gradients(nx.from_numpy(cost, type_as=M0),
- (a0, b0, M0), (nx.from_numpy(u, type_as=a0),
- nx.from_numpy(v, type_as=b0), G))
+ if not nx.is_floating_point(type_as):
+ warnings.warn(
+ "Input histogram consists of integer. The transport plan will be "
+ "casted accordingly, possibly resulting in a loss of precision. "
+ "If this behaviour is unwanted, please make sure your input "
+ "histogram consists of floating point elements.",
+ stacklevel=2
+ )
+ G = nx.from_numpy(G, type_as=type_as)
+ cost = nx.set_gradients(nx.from_numpy(cost, type_as=type_as),
+ (a0, b0, M0), (nx.from_numpy(u, type_as=type_as),
+ nx.from_numpy(v, type_as=type_as), G))
check_result(result_code)
return cost