From 0cb9addd8316e73dd6b5919d6631957c45f2c564 Mon Sep 17 00:00:00 2001 From: Kefu Chai Date: Fri, 22 Mar 2019 21:28:26 +0800 Subject: [PATCH] common/weighted_shufft: rewrite it to avoid copyright issue Signed-off-by: Kefu Chai --- src/common/weighted_shuffle.h | 34 +++++++++++------------ src/mon/MonClient.cc | 4 +-- src/test/test_weighted_shuffle.cc | 45 ++++++++++++++----------------- 3 files changed, 37 insertions(+), 46 deletions(-) diff --git a/src/common/weighted_shuffle.h b/src/common/weighted_shuffle.h index 98c00c1211241..10def0a011a41 100644 --- a/src/common/weighted_shuffle.h +++ b/src/common/weighted_shuffle.h @@ -1,29 +1,25 @@ // -*- mode:C++; tab-width:8; c-basic-offset:2; indent-tabs-mode:t -*- // vim: ts=8 sw=2 smarttab -// https://stackoverflow.com/questions/50221136/c-weighted-stdshuffle/50223540 -#ifndef CEPH_WEIGHTED_SHUFFLE_H -#define CEPH_WEIGHTED_SHUFFLE_H + +#pragma once #include +#include #include -#include -template -void weighted_shuffle(std::vector &data, std::vector &weights, R &&gen) +template +void weighted_shuffle(RandomIt first, RandomIt last, + DistIt weight_first, DistIt weight_last, + URBG &&g) { - auto itd = data.begin(); - auto itw = weights.begin(); - - while (itd != data.end() && itw != weights.end()) { - std::discrete_distribution d(itw, weights.end()); - auto i = d(gen); - if (i) { - std::iter_swap(itd, std::next(itd, i)); - std::iter_swap(itw, std::next(itw, i)); + if (first == last) { + return; + } else { + std::discrete_distribution d{weight_first, weight_last}; + if (auto n = d(g); n > 0) { + std::iter_swap(first, std::next(first, n)); + std::iter_swap(weight_first, std::next(weight_first, n)); } - ++itd; - ++itw; + weighted_shuffle(++first, last, ++weight_first, weight_last, std::move(g)); } } - -#endif diff --git a/src/mon/MonClient.cc b/src/mon/MonClient.cc index ef81f5356e89f..3c361ef59e263 100644 --- a/src/mon/MonClient.cc +++ b/src/mon/MonClient.cc @@ -730,8 +730,8 @@ void MonClient::_add_conns(uint64_t global_id) weights.push_back(monmap.get_weight(rank_name)); } std::random_device rd; - std::mt19937 gen(rd()); - weighted_shuffle(ranks, weights, gen); + weighted_shuffle(begin(ranks), end(ranks), begin(weights), end(weights), + std::mt19937{rd()}); } ldout(cct, 10) << __func__ << " ranks=" << ranks << dendl; unsigned n = cct->_conf->mon_client_hunt_parallel; diff --git a/src/test/test_weighted_shuffle.cc b/src/test/test_weighted_shuffle.cc index 47d1438926cc2..7e881134a7f78 100644 --- a/src/test/test_weighted_shuffle.cc +++ b/src/test/test_weighted_shuffle.cc @@ -1,20 +1,14 @@ // -*- mode:C++; tab-width:8; c-basic-offset:2; indent-tabs-mode:t -*- // vim: ts=8 sw=2 smarttab -// https://stackoverflow.com/questions/50221136/c-weighted-stdshuffle/50223540 #include "common/weighted_shuffle.h" #include #include "gtest/gtest.h" TEST(WeightedShuffle, Basic) { - std::random_device rd; - std::mt19937 gen(rd()); - - std::vector data{'a', 'b', 'c', 'd', 'e'}; - // NB: differences between weights should be significant - // otherwise test might fail - std::vector weights{100, 50, 25, 10, 1}; - std::map> frequency { + std::array choices{'a', 'b', 'c', 'd', 'e'}; + std::array weights{100, 50, 25, 10, 1}; + std::map> frequency { {'a', {0, 0, 0, 0, 0}}, {'b', {0, 0, 0, 0, 0}}, {'c', {0, 0, 0, 0, 0}}, @@ -22,22 +16,23 @@ TEST(WeightedShuffle, Basic) { {'e', {0, 0, 0, 0, 0}} }; // count each element appearing in each position const int samples = 10000; + std::random_device rd; for (auto i = 0; i < samples; i++) { - weighted_shuffle(data, weights, gen); - for (size_t j = 0; j < data.size(); ++j) - ++frequency[data[j]][j]; + weighted_shuffle(begin(choices), end(choices), + begin(weights), end(weights), + std::mt19937{rd()}); + for (size_t j = 0; j < choices.size(); ++j) + ++frequency[choices[j]][j]; + } + // verify that the probability that the nth choice is selected as the first + // one is the nth weight divided by the sum of all weights + const auto total_weight = std::accumulate(weights.begin(), weights.end(), 0); + constexpr float epsilon = 0.01; + for (unsigned i = 0; i < choices.size(); i++) { + const auto& f = frequency[choices[i]]; + const auto& w = weights[i]; + ASSERT_NEAR(float(w) / total_weight, + float(f.front()) / samples, + epsilon); } - - // verify each element gets a chance to stay at the header of array - // e.g., because we don't want to starve anybody! - EXPECT_TRUE(std::all_of(frequency.begin(), frequency.end(), [](auto& it) { - return it.second.front() > 0; - })); - - // verify the probability (of staying at the array header) - // are produced according to their corresponding weight - EXPECT_TRUE(std::is_sorted(frequency.begin(), frequency.end(), - [](auto& lhs, auto& rhs) { - return lhs.second.front() > rhs.second.front(); - })); } -- 2.39.5