summaryrefslogtreecommitdiff
path: root/geom_matching/wasserstein/include/dnn/parallel/utils.h
diff options
context:
space:
mode:
Diffstat (limited to 'geom_matching/wasserstein/include/dnn/parallel/utils.h')
-rw-r--r--geom_matching/wasserstein/include/dnn/parallel/utils.h20
1 files changed, 13 insertions, 7 deletions
diff --git a/geom_matching/wasserstein/include/dnn/parallel/utils.h b/geom_matching/wasserstein/include/dnn/parallel/utils.h
index ba73814..7104ec3 100644
--- a/geom_matching/wasserstein/include/dnn/parallel/utils.h
+++ b/geom_matching/wasserstein/include/dnn/parallel/utils.h
@@ -1,8 +1,12 @@
-#ifndef PARALLEL_UTILS_H
-#define PARALLEL_UTILS_H
+#ifndef HERA_WS_PARALLEL_UTILS_H
+#define HERA_WS_PARALLEL_UTILS_H
#include "../utils.h"
+namespace hera
+{
+namespace ws
+{
namespace dnn
{
// Assumes rng is synchronized across ranks
@@ -15,11 +19,13 @@ namespace dnn
typedef decltype(data[0]) T;
shuffle(world, data, rng, [](T& x, T& y) { std::swap(x,y); });
}
-}
+} // dnn
+} // ws
+} // hera
template<class DataVector, class RNGType, class SwapFunctor>
void
-dnn::shuffle(mpi::communicator& world, DataVector& data, RNGType& rng, const SwapFunctor& swap, DataVector empty)
+hera::ws::dnn::shuffle(mpi::communicator& world, DataVector& data, RNGType& rng, const SwapFunctor& swap, DataVector empty)
{
// This is not a perfect shuffle: it dishes out data in chunks of 1/size.
// (It can be interpreted as generating a bistochastic matrix by taking the
@@ -42,7 +48,7 @@ dnn::shuffle(mpi::communicator& world, DataVector& data, RNGType& rng, const Swa
RNGType local_rng(seed);
// Shuffle local data
- dnn::random_shuffle(data.begin(), data.end(), local_rng, swap);
+ hera::ws::dnn::random_shuffle(data.begin(), data.end(), local_rng, swap);
// Decide how much of our data goes to i-th processor
std::vector<size_t> out_counts(size);
@@ -50,7 +56,7 @@ dnn::shuffle(mpi::communicator& world, DataVector& data, RNGType& rng, const Swa
boost::counting_iterator<int>(size));
for (size_t i = 0; i < size; ++i)
{
- dnn::random_shuffle(ranks.begin(), ranks.end(), rng);
+ hera::ws::dnn::random_shuffle(ranks.begin(), ranks.end(), rng);
++out_counts[ranks[rank]];
}
@@ -87,7 +93,7 @@ dnn::shuffle(mpi::communicator& world, DataVector& data, RNGType& rng, const Swa
for(const DataVector& vec : incoming)
for (size_t i = 0; i < vec.size(); ++i)
data.push_back(vec[i]);
- dnn::random_shuffle(data.begin(), data.end(), local_rng, swap);
+ hera::ws::dnn::random_shuffle(data.begin(), data.end(), local_rng, swap);
// XXX: the final shuffle is irrelevant for our purposes. But it's also cheap.
}