diff options
author | Arnur Nigmetov <a.nigmetov@gmail.com> | 2018-01-20 19:11:29 +0100 |
---|---|---|
committer | Arnur Nigmetov <a.nigmetov@gmail.com> | 2018-01-20 19:11:29 +0100 |
commit | 0cc35ad04f9c2997014d7cf62a12f697e79fb534 (patch) | |
tree | 744c07bc2c12fba193934ac98417c5063bead189 /geom_matching/wasserstein/include/dnn | |
parent | 3552ce68bc7654df35da471bd937b09a9fde101f (diff) |
Major rewrite, templatized version
Diffstat (limited to 'geom_matching/wasserstein/include/dnn')
8 files changed, 367 insertions, 56 deletions
diff --git a/geom_matching/wasserstein/include/dnn/geometry/euclidean-dynamic.h b/geom_matching/wasserstein/include/dnn/geometry/euclidean-dynamic.h new file mode 100644 index 0000000..4b98309 --- /dev/null +++ b/geom_matching/wasserstein/include/dnn/geometry/euclidean-dynamic.h @@ -0,0 +1,248 @@ +#ifndef DNN_GEOMETRY_EUCLIDEAN_DYNAMIC_H +#define DNN_GEOMETRY_EUCLIDEAN_DYNAMIC_H + +#include <vector> +#include <algorithm> +#include <boost/iterator/iterator_facade.hpp> +#include <boost/serialization/access.hpp> +#include <boost/serialization/vector.hpp> +#include <cmath> + +namespace hera +{ +namespace ws +{ +namespace dnn +{ + +template<class Real_> +class DynamicPointVector +{ + public: + using Real = Real_; + struct PointType + { + void* p; + + Real& operator[](const int i) + { + return (static_cast<Real*>(p))[i]; + } + + const Real& operator[](const int i) const + { + return (static_cast<Real*>(p))[i]; + } + + }; + struct iterator; + typedef iterator const_iterator; + + public: + DynamicPointVector(size_t point_capacity = 0): + point_capacity_(point_capacity) {} + + + PointType operator[](size_t i) const { return {(void*) &storage_[i*point_capacity_]}; } + inline void push_back(PointType p); + + inline iterator begin(); + inline iterator end(); + inline const_iterator begin() const; + inline const_iterator end() const; + + size_t size() const { return storage_.size() / point_capacity_; } + + void clear() { storage_.clear(); } + void swap(DynamicPointVector& other) { storage_.swap(other.storage_); std::swap(point_capacity_, other.point_capacity_); } + void reserve(size_t sz) { storage_.reserve(sz * point_capacity_); } + void resize(size_t sz) { storage_.resize(sz * point_capacity_); } + + private: + size_t point_capacity_; + std::vector<char> storage_; + + private: + friend class boost::serialization::access; + + template<class Archive> + void serialize(Archive& ar, const unsigned int version) { ar & point_capacity_ & storage_; } +}; + +template<typename Real> +struct DynamicPointTraits +{ + typedef DynamicPointVector<Real> PointContainer; + typedef typename PointContainer::PointType PointType; + struct PointHandle + { + void* p; + bool operator==(const PointHandle& other) const { return p == other.p; } + bool operator!=(const PointHandle& other) const { return !(*this == other); } + bool operator<(const PointHandle& other) const { return p < other.p; } + bool operator>(const PointHandle& other) const { return p > other.p; } + }; + + typedef Real Coordinate; + typedef Real DistanceType; + + DynamicPointTraits(unsigned dim = 0): + dim_(dim) {} + + DistanceType distance(PointType p1, PointType p2) const { return sqrt(sq_distance(p1,p2)); } + DistanceType distance(PointHandle p1, PointHandle p2) const { return distance(PointType({p1.p}), PointType({p2.p})); } + DistanceType sq_distance(PointType p1, PointType p2) const { Real res = 0; for (unsigned i = 0; i < dimension(); ++i) { Real c1 = coordinate(p1,i), c2 = coordinate(p2,i); res += (c1 - c2)*(c1 - c2); } return res; } + DistanceType sq_distance(PointHandle p1, PointHandle p2) const { return sq_distance(PointType({p1.p}), PointType({p2.p})); } + unsigned dimension() const { return dim_; } + Real& coordinate(PointType p, unsigned i) const { return ((Real*) p.p)[i]; } + Real& coordinate(PointHandle h, unsigned i) const { return ((Real*) h.p)[i]; } + + // it's non-standard to return a reference, but we can rely on it for code that assumes this particular point type + size_t& id(PointType p) const { return *((size_t*) ((Real*) p.p + dimension())); } + size_t& id(PointHandle h) const { return *((size_t*) ((Real*) h.p + dimension())); } + PointHandle handle(PointType p) const { return {p.p}; } + PointType point(PointHandle h) const { return {h.p}; } + + void swap(PointType p1, PointType p2) const { std::swap_ranges((char*) p1.p, ((char*) p1.p) + capacity(), (char*) p2.p); } + bool cmp(PointType p1, PointType p2) const { return std::lexicographical_compare((Real*) p1.p, ((Real*) p1.p) + dimension(), (Real*) p2.p, ((Real*) p2.p) + dimension()); } + bool eq(PointType p1, PointType p2) const { return std::equal((Real*) p1.p, ((Real*) p1.p) + dimension(), (Real*) p2.p); } + + // non-standard, and possibly a weird name + size_t capacity() const { return sizeof(Real)*dimension() + sizeof(size_t); } + + PointContainer container(size_t n = 0) const { PointContainer c(capacity()); c.resize(n); return c; } + PointContainer container(size_t n, const PointType& p) const; + + typename PointContainer::iterator + iterator(PointContainer& c, PointHandle ph) const; + typename PointContainer::const_iterator + iterator(const PointContainer& c, PointHandle ph) const; + + Real internal_p; + + private: + unsigned dim_; + + private: + friend class boost::serialization::access; + + template<class Archive> + void serialize(Archive& ar, const unsigned int version) { ar & dim_; } +}; + +} // dnn + +template<class Real> +struct dnn::DynamicPointVector<Real>::iterator: + public boost::iterator_facade<iterator, + PointType, + std::random_access_iterator_tag, + PointType, + std::ptrdiff_t> +{ + typedef boost::iterator_facade<iterator, + PointType, + std::random_access_iterator_tag, + PointType, + std::ptrdiff_t> Parent; + + + public: + typedef typename Parent::value_type value_type; + typedef typename Parent::difference_type difference_type; + typedef typename Parent::reference reference; + + iterator(size_t point_capacity = 0): + point_capacity_(point_capacity) {} + + iterator(void* p, size_t point_capacity): + p_(p), point_capacity_(point_capacity) {} + + private: + void increment() { p_ = ((char*) p_) + point_capacity_; } + void decrement() { p_ = ((char*) p_) - point_capacity_; } + void advance(difference_type n) { p_ = ((char*) p_) + n*point_capacity_; } + difference_type + distance_to(iterator other) const { return (((char*) other.p_) - ((char*) p_))/(int) point_capacity_; } + bool equal(const iterator& other) const { return p_ == other.p_; } + reference dereference() const { return {p_}; } + + friend class ::boost::iterator_core_access; + + private: + void* p_; + size_t point_capacity_; +}; + +template<class Real> +void dnn::DynamicPointVector<Real>::push_back(PointType p) +{ + if (storage_.capacity() < storage_.size() + point_capacity_) + storage_.reserve(1.5*storage_.capacity()); + + storage_.resize(storage_.size() + point_capacity_); + + std::copy((char*) p.p, (char*) p.p + point_capacity_, storage_.end() - point_capacity_); +} + +template<class Real> +typename dnn::DynamicPointVector<Real>::iterator dnn::DynamicPointVector<Real>::begin() { return iterator((void*) &*storage_.begin(), point_capacity_); } + +template<class Real> +typename dnn::DynamicPointVector<Real>::iterator dnn::DynamicPointVector<Real>::end() { return iterator((void*) &*storage_.end(), point_capacity_); } + +template<class Real> +typename dnn::DynamicPointVector<Real>::const_iterator dnn::DynamicPointVector<Real>::begin() const { return const_iterator((void*) &*storage_.begin(), point_capacity_); } + +template<class Real> +typename dnn::DynamicPointVector<Real>::const_iterator dnn::DynamicPointVector<Real>::end() const { return const_iterator((void*) &*storage_.end(), point_capacity_); } + +template<typename R> +typename dnn::DynamicPointTraits<R>::PointContainer +dnn::DynamicPointTraits<R>::container(size_t n, const PointType& p) const +{ + PointContainer c = container(n); + for (auto x : c) + std::copy((char*) p.p, (char*) p.p + capacity(), (char*) x.p); + return c; +} + +template<typename R> +typename dnn::DynamicPointTraits<R>::PointContainer::iterator +dnn::DynamicPointTraits<R>::iterator(PointContainer& c, PointHandle ph) const +{ return typename PointContainer::iterator(ph.p, capacity()); } + +template<typename R> +typename dnn::DynamicPointTraits<R>::PointContainer::const_iterator +dnn::DynamicPointTraits<R>::iterator(const PointContainer& c, PointHandle ph) const +{ return typename PointContainer::const_iterator(ph.p, capacity()); } + +} // ws +} // hera + +namespace std { + template<> + struct hash<typename hera::ws::dnn::DynamicPointTraits<double>::PointHandle> + { + using PointHandle = typename hera::ws::dnn::DynamicPointTraits<double>::PointHandle; + size_t operator()(const PointHandle& ph) const + { + return std::hash<void*>()(ph.p); + } + }; + + template<> + struct hash<typename hera::ws::dnn::DynamicPointTraits<float>::PointHandle> + { + using PointHandle = typename hera::ws::dnn::DynamicPointTraits<float>::PointHandle; + size_t operator()(const PointHandle& ph) const + { + return std::hash<void*>()(ph.p); + } + }; + + +} // std + + +#endif diff --git a/geom_matching/wasserstein/include/dnn/geometry/euclidean-fixed.h b/geom_matching/wasserstein/include/dnn/geometry/euclidean-fixed.h index e2c5b44..3e38baf 100644 --- a/geom_matching/wasserstein/include/dnn/geometry/euclidean-fixed.h +++ b/geom_matching/wasserstein/include/dnn/geometry/euclidean-fixed.h @@ -1,5 +1,5 @@ -#ifndef DNN_GEOMETRY_EUCLIDEAN_FIXED_H -#define DNN_GEOMETRY_EUCLIDEAN_FIXED_H +#ifndef HERA_WS_DNN_GEOMETRY_EUCLIDEAN_FIXED_H +#define HERA_WS_DNN_GEOMETRY_EUCLIDEAN_FIXED_H #include <boost/operators.hpp> #include <boost/array.hpp> @@ -15,6 +15,10 @@ #include "../parallel/tbb.h" // for dnn::vector<...> +namespace hera +{ +namespace ws +{ namespace dnn { // TODO: wrap in another namespace (e.g., euclidean) @@ -107,7 +111,7 @@ namespace dnn template<class Point> struct PointTraits; // intentionally undefined; should be specialized for each type - + template<size_t D, typename Real> struct PointTraits< Point<D, Real> > // specialization for dnn::Point { @@ -119,11 +123,11 @@ namespace dnn typedef typename PointType::DistanceType DistanceType; - static DistanceType - distance(const PointType& p1, const PointType& p2) { if (std::isinf(internal_p)) return p1.distance(p2); else return p1.p_distance(p2, internal_p); } + static DistanceType + distance(const PointType& p1, const PointType& p2) { if (hera::is_infinity(internal_p)) return p1.distance(p2); else return p1.p_distance(p2, internal_p); } - static DistanceType - distance(PointHandle p1, PointHandle p2) { return distance(*p1,*p2); } + static DistanceType + distance(PointHandle p1, PointHandle p2) { return distance(*p1,*p2); } static size_t dimension() { return D; } static Real coordinate(const PointType& p, size_t i) { return p[i]; } @@ -163,8 +167,8 @@ namespace dnn }; template<size_t D, typename Real> - Real PointTraits< Point<D, Real> >::internal_p = std::numeric_limits<Real>::infinity(); - + Real PointTraits< Point<D, Real> >::internal_p = hera::get_infinity<Real>(); + template<class PointContainer> void read_points(const std::string& filename, PointContainer& points) @@ -185,6 +189,8 @@ namespace dnn points.back()[i++] = x; } } -} +} // dnn +} // ws +} // hera #endif diff --git a/geom_matching/wasserstein/include/dnn/local/kd-tree.h b/geom_matching/wasserstein/include/dnn/local/kd-tree.h index 13eaf27..8e52a5c 100644 --- a/geom_matching/wasserstein/include/dnn/local/kd-tree.h +++ b/geom_matching/wasserstein/include/dnn/local/kd-tree.h @@ -1,5 +1,5 @@ -#ifndef DNN_LOCAL_KD_TREE_H -#define DNN_LOCAL_KD_TREE_H +#ifndef HERA_WS_DNN_LOCAL_KD_TREE_H +#define HERA_WS_DNN_LOCAL_KD_TREE_H #include "../utils.h" #include "search-functors.h" @@ -13,6 +13,10 @@ #include <boost/static_assert.hpp> #include <boost/type_traits.hpp> +namespace hera +{ +namespace ws +{ namespace dnn { // Weighted KDTree @@ -48,8 +52,9 @@ namespace dnn template<class Range> void init(const Range& range); - DistanceType weight(PointHandle p) { return weights_[indices_[p]]; } + DistanceType weight(PointHandle p) { return weights_[indices_[p]]; } void change_weight(PointHandle p, DistanceType w); + void adjust_weights(DistanceType delta); // subtract delta from all weights HandleDistance find(PointHandle q) const; Result findR(PointHandle q, DistanceType r) const; // all neighbors within r @@ -83,7 +88,9 @@ namespace dnn HandleMap indices_; double wassersteinPower; }; -} +} // dnn +} // ws +} // hera #include "kd-tree.hpp" diff --git a/geom_matching/wasserstein/include/dnn/local/kd-tree.hpp b/geom_matching/wasserstein/include/dnn/local/kd-tree.hpp index 22108aa..3a4f0eb 100644 --- a/geom_matching/wasserstein/include/dnn/local/kd-tree.hpp +++ b/geom_matching/wasserstein/include/dnn/local/kd-tree.hpp @@ -9,14 +9,14 @@ #include "def_debug_ws.h" template<class T> -dnn::KDTree<T>:: +hera::ws::dnn::KDTree<T>:: KDTree(const Traits& traits, HandleContainer&& handles, double _wassersteinPower): traits_(traits), tree_(std::move(handles)), wassersteinPower(_wassersteinPower) { assert(wassersteinPower >= 1.0); init(); } template<class T> template<class Range> -dnn::KDTree<T>:: +hera::ws::dnn::KDTree<T>:: KDTree(const Traits& traits, const Range& range, double _wassersteinPower): traits_(traits), wassersteinPower(_wassersteinPower) { @@ -27,7 +27,7 @@ KDTree(const Traits& traits, const Range& range, double _wassersteinPower): template<class T> template<class Range> void -dnn::KDTree<T>:: +hera::ws::dnn::KDTree<T>:: init(const Range& range) { size_t sz = std::distance(std::begin(range), std::end(range)); @@ -41,7 +41,7 @@ init(const Range& range) template<class T> void -dnn::KDTree<T>:: +hera::ws::dnn::KDTree<T>:: init() { if (tree_.empty()) @@ -61,7 +61,7 @@ init() template<class T> struct -dnn::KDTree<T>::OrderTree +hera::ws::dnn::KDTree<T>::OrderTree { OrderTree(HCIterator b_, HCIterator e_, size_t i_, const Traits& traits_): b(b_), e(e_), i(i_), traits(traits_) {} @@ -114,7 +114,7 @@ dnn::KDTree<T>::OrderTree template<class T> template<class ResultsFunctor> void -dnn::KDTree<T>:: +hera::ws::dnn::KDTree<T>:: search(PointHandle q, ResultsFunctor& rf) const { typedef typename HandleContainer::const_iterator HCIterator; @@ -123,7 +123,7 @@ search(PointHandle q, ResultsFunctor& rf) const if (tree_.empty()) return; - DistanceType D = std::numeric_limits<DistanceType>::infinity(); + DistanceType D = std::numeric_limits<DistanceType>::max(); // TODO: use tbb::scalable_allocator for the queue std::queue<KDTreeNode> nodes; @@ -140,14 +140,16 @@ search(PointHandle q, ResultsFunctor& rf) const i = (i + 1) % traits().dimension(); HCIterator m = b + (e - b)/2; - DistanceType dist = pow(traits().distance(q, *m), wassersteinPower) + weights_[m - tree_.begin()]; + + DistanceType dist = (wassersteinPower == 1.0) ? traits().distance(q, *m) + weights_[m - tree_.begin()] : std::pow(traits().distance(q, *m), wassersteinPower) + weights_[m - tree_.begin()]; D = rf(*m, dist); // we are really searching w.r.t L_\infty ball; could prune better with an L_2 ball Coordinate diff = cmp.diff(q, *m); // diff returns signed distance - DistanceType diffToWasserPower = (diff > 0 ? 1.0 : -1.0) * pow(fabs(diff), wassersteinPower); + + DistanceType diffToWasserPower = (wassersteinPower == 1.0) ? diff : ((diff > 0 ? 1.0 : -1.0) * std::pow(fabs(diff), wassersteinPower)); size_t lm = m + 1 + (e - (m+1))/2 - tree_.begin(); if (e > m + 1 && diffToWasserPower - subtree_weights_[lm] >= -D) { @@ -163,7 +165,20 @@ search(PointHandle q, ResultsFunctor& rf) const template<class T> void -dnn::KDTree<T>:: +hera::ws::dnn::KDTree<T>:: +adjust_weights(DistanceType delta) +{ + for(auto& w : weights_) + w -= delta; + + for(auto& sw : subtree_weights_) + sw -= delta; +} + + +template<class T> +void +hera::ws::dnn::KDTree<T>:: change_weight(PointHandle p, DistanceType w) { size_t idx = indices_[p]; @@ -246,32 +261,32 @@ change_weight(PointHandle p, DistanceType w) } template<class T> -typename dnn::KDTree<T>::HandleDistance -dnn::KDTree<T>:: +typename hera::ws::dnn::KDTree<T>::HandleDistance +hera::ws::dnn::KDTree<T>:: find(PointHandle q) const { - dnn::NNRecord<HandleDistance> nn; + hera::ws::dnn::NNRecord<HandleDistance> nn; search(q, nn); return nn.result; } template<class T> -typename dnn::KDTree<T>::Result -dnn::KDTree<T>:: +typename hera::ws::dnn::KDTree<T>::Result +hera::ws::dnn::KDTree<T>:: findR(PointHandle q, DistanceType r) const { - dnn::rNNRecord<HandleDistance> rnn(r); + hera::ws::dnn::rNNRecord<HandleDistance> rnn(r); search(q, rnn); std::sort(rnn.result.begin(), rnn.result.end()); return rnn.result; } template<class T> -typename dnn::KDTree<T>::Result -dnn::KDTree<T>:: +typename hera::ws::dnn::KDTree<T>::Result +hera::ws::dnn::KDTree<T>:: findK(PointHandle q, size_t k) const { - dnn::kNNRecord<HandleDistance> knn(k); + hera::ws::dnn::kNNRecord<HandleDistance> knn(k); search(q, knn); std::sort(knn.result.begin(), knn.result.end()); return knn.result; @@ -279,7 +294,7 @@ findK(PointHandle q, size_t k) const template<class T> -struct dnn::KDTree<T>::CoordinateComparison +struct hera::ws::dnn::KDTree<T>::CoordinateComparison { CoordinateComparison(size_t i, const Traits& traits): i_(i), traits_(traits) {} @@ -297,7 +312,7 @@ struct dnn::KDTree<T>::CoordinateComparison template<class T> void -dnn::KDTree<T>:: +hera::ws::dnn::KDTree<T>:: printWeights(void) { #ifndef FOR_R_TDA diff --git a/geom_matching/wasserstein/include/dnn/local/search-functors.h b/geom_matching/wasserstein/include/dnn/local/search-functors.h index f257d0c..1419f22 100644 --- a/geom_matching/wasserstein/include/dnn/local/search-functors.h +++ b/geom_matching/wasserstein/include/dnn/local/search-functors.h @@ -1,8 +1,12 @@ -#ifndef DNN_LOCAL_SEARCH_FUNCTORS_H -#define DNN_LOCAL_SEARCH_FUNCTORS_H +#ifndef HERA_WS_DNN_LOCAL_SEARCH_FUNCTORS_H +#define HERA_WS_DNN_LOCAL_SEARCH_FUNCTORS_H #include <boost/range/algorithm/heap_algorithm.hpp> +namespace hera +{ +namespace ws +{ namespace dnn { @@ -28,7 +32,7 @@ struct NNRecord typedef typename HandleDistance::PointHandle PointHandle; typedef typename HandleDistance::DistanceType DistanceType; - NNRecord() { result.d = std::numeric_limits<DistanceType>::infinity(); } + NNRecord() { result.d = std::numeric_limits<DistanceType>::max(); } DistanceType operator()(PointHandle p, DistanceType d) { if (d < result.d) { result.p = p; result.d = d; } return result.d; } HandleDistance result; }; @@ -67,7 +71,7 @@ struct kNNRecord result.push_back(HandleDistance(p,d)); boost::push_heap(result); if (result.size() < k) - return std::numeric_limits<DistanceType>::infinity(); + return std::numeric_limits<DistanceType>::max(); } else if (d < result[0].d) { boost::pop_heap(result); @@ -84,6 +88,8 @@ struct kNNRecord HDContainer result; }; -} +} // dnn +} // ws +} // hera #endif // DNN_LOCAL_SEARCH_FUNCTORS_H diff --git a/geom_matching/wasserstein/include/dnn/parallel/tbb.h b/geom_matching/wasserstein/include/dnn/parallel/tbb.h index 64c59e0..3f811d6 100644 --- a/geom_matching/wasserstein/include/dnn/parallel/tbb.h +++ b/geom_matching/wasserstein/include/dnn/parallel/tbb.h @@ -1,7 +1,6 @@ -#ifndef PARALLEL_H -#define PARALLEL_H +#ifndef HERA_WS_PARALLEL_H +#define HERA_WS_PARALLEL_H -//#include <iostream> #include <vector> #include <boost/range.hpp> @@ -18,6 +17,10 @@ #include <boost/serialization/collections_load_imp.hpp> #include <boost/serialization/collections_save_imp.hpp> +namespace hera +{ +namespace ws +{ namespace dnn { using tbb::mutex; @@ -87,7 +90,9 @@ namespace dnn tbb::tick_count start; }; -} +} // dnn +} // ws +} // hera // Serialization for tbb::concurrent_vector<...> namespace boost @@ -132,6 +137,10 @@ namespace boost #include <map> #include <boost/progress.hpp> +namespace hera +{ +namespace ws +{ namespace dnn { template<class T> @@ -207,14 +216,22 @@ namespace dnn }; using boost::progress_timer; -} +} // dnn +} // ws +} // hera #endif // TBB +namespace hera +{ +namespace ws +{ namespace dnn { template<class Range, class F> void do_foreach(const Range& range, const F& f) { do_foreach(boost::begin(range), boost::end(range), f); } -} +} // dnn +} // ws +} // hera #endif diff --git a/geom_matching/wasserstein/include/dnn/parallel/utils.h b/geom_matching/wasserstein/include/dnn/parallel/utils.h index ba73814..7104ec3 100644 --- a/geom_matching/wasserstein/include/dnn/parallel/utils.h +++ b/geom_matching/wasserstein/include/dnn/parallel/utils.h @@ -1,8 +1,12 @@ -#ifndef PARALLEL_UTILS_H -#define PARALLEL_UTILS_H +#ifndef HERA_WS_PARALLEL_UTILS_H +#define HERA_WS_PARALLEL_UTILS_H #include "../utils.h" +namespace hera +{ +namespace ws +{ namespace dnn { // Assumes rng is synchronized across ranks @@ -15,11 +19,13 @@ namespace dnn typedef decltype(data[0]) T; shuffle(world, data, rng, [](T& x, T& y) { std::swap(x,y); }); } -} +} // dnn +} // ws +} // hera template<class DataVector, class RNGType, class SwapFunctor> void -dnn::shuffle(mpi::communicator& world, DataVector& data, RNGType& rng, const SwapFunctor& swap, DataVector empty) +hera::ws::dnn::shuffle(mpi::communicator& world, DataVector& data, RNGType& rng, const SwapFunctor& swap, DataVector empty) { // This is not a perfect shuffle: it dishes out data in chunks of 1/size. // (It can be interpreted as generating a bistochastic matrix by taking the @@ -42,7 +48,7 @@ dnn::shuffle(mpi::communicator& world, DataVector& data, RNGType& rng, const Swa RNGType local_rng(seed); // Shuffle local data - dnn::random_shuffle(data.begin(), data.end(), local_rng, swap); + hera::ws::dnn::random_shuffle(data.begin(), data.end(), local_rng, swap); // Decide how much of our data goes to i-th processor std::vector<size_t> out_counts(size); @@ -50,7 +56,7 @@ dnn::shuffle(mpi::communicator& world, DataVector& data, RNGType& rng, const Swa boost::counting_iterator<int>(size)); for (size_t i = 0; i < size; ++i) { - dnn::random_shuffle(ranks.begin(), ranks.end(), rng); + hera::ws::dnn::random_shuffle(ranks.begin(), ranks.end(), rng); ++out_counts[ranks[rank]]; } @@ -87,7 +93,7 @@ dnn::shuffle(mpi::communicator& world, DataVector& data, RNGType& rng, const Swa for(const DataVector& vec : incoming) for (size_t i = 0; i < vec.size(); ++i) data.push_back(vec[i]); - dnn::random_shuffle(data.begin(), data.end(), local_rng, swap); + hera::ws::dnn::random_shuffle(data.begin(), data.end(), local_rng, swap); // XXX: the final shuffle is irrelevant for our purposes. But it's also cheap. } diff --git a/geom_matching/wasserstein/include/dnn/utils.h b/geom_matching/wasserstein/include/dnn/utils.h index 83c2865..bbce793 100644 --- a/geom_matching/wasserstein/include/dnn/utils.h +++ b/geom_matching/wasserstein/include/dnn/utils.h @@ -1,10 +1,14 @@ -#ifndef DNN_UTILS_H -#define DNN_UTILS_H +#ifndef HERA_WS_DNN_UTILS_H +#define HERA_WS_DNN_UTILS_H #include <boost/random/uniform_int.hpp> #include <boost/foreach.hpp> #include <boost/typeof/typeof.hpp> +namespace hera +{ +namespace ws +{ namespace dnn { @@ -36,6 +40,8 @@ void random_shuffle(RandomIt first, RandomIt last, UniformRandomNumberGenerator& random_shuffle(first, last, g, [](T& x, T& y) { std::swap(x,y); }); } -} +} // dnn +} // ws +} // hera #endif |