// -*- mode:C++; tab-width:8; c-basic-offset:2; indent-tabs-mode:t -*-
// vim: ts=8 sw=2 smarttab
-// https://stackoverflow.com/questions/50221136/c-weighted-stdshuffle/50223540
-#ifndef CEPH_WEIGHTED_SHUFFLE_H
-#define CEPH_WEIGHTED_SHUFFLE_H
+
+#pragma once
#include <algorithm>
+#include <iterator>
#include <random>
-#include <vector>
/// Shuffle [first, last) in place so that elements with a larger weight tend
/// to end up closer to the front.
///
/// For every position, one of the remaining elements is drawn with a
/// probability proportional to its weight (via std::discrete_distribution)
/// and swapped into that position. The weights are permuted in lockstep with
/// the elements, so after the call the i-th weight still belongs to the i-th
/// element.
///
/// @param first,last               elements to shuffle
/// @param weight_first,weight_last weights of the elements, in the same
///                                 order; the values must be convertible to
///                                 double and non-negative
/// @param g                        uniform random bit generator; both lvalues
///                                 and temporaries are accepted
///
/// If the two ranges differ in length, only the leading part covered by both
/// of them is shuffled; the rest of the longer range is left untouched.
///
/// @note std::discrete_distribution requires the sum of the weights it is
///       given to be positive, so the weights not yet consumed must never
///       all be zero.
template <class RandomIt, class DistIt, class URBG>
void weighted_shuffle(RandomIt first, RandomIt last,
                      DistIt weight_first, DistIt weight_last,
                      URBG &&g)
{
  // Clamp both ranges to their common length. Otherwise a surplus weight
  // could select an index past `last`, or a surplus element could advance
  // `weight_first` past `weight_last`.
  const auto count = std::distance(first, last);
  const auto weight_count = std::distance(weight_first, weight_last);
  if (weight_count < count) {
    last = std::next(first, weight_count);
  } else if (count < weight_count) {
    weight_last = std::next(weight_first, count);
  }

  // Iterate instead of recursing: the stack depth stays constant and the
  // generator is only ever used as an lvalue, never moved from.
  while (first != last) {
    // Parentheses on purpose: braces could select the
    // initializer_list<double> constructor instead of the iterator one.
    std::discrete_distribution<> pick(weight_first, weight_last);
    if (const auto n = pick(g); n > 0) {
      std::iter_swap(first, std::next(first, n));
      std::iter_swap(weight_first, std::next(weight_first, n));
    }
    ++first;
    ++weight_first;
  }
}
-
-#endif
// -*- mode:C++; tab-width:8; c-basic-offset:2; indent-tabs-mode:t -*-
// vim: ts=8 sw=2 smarttab
-// https://stackoverflow.com/questions/50221136/c-weighted-stdshuffle/50223540
#include "common/weighted_shuffle.h"
#include <array>
#include <map>
#include <numeric>
#include <random>
#include "gtest/gtest.h"
// Statistical check of weighted_shuffle(): over many independent shuffles,
// the fraction of runs in which a choice lands in the first position has to
// match its share of the total weight.
TEST(WeightedShuffle, Basic) {
  std::array<char, 5> choices{'a', 'b', 'c', 'd', 'e'};
  std::array<int, 5> weights{100, 50, 25, 10, 1};
  // frequency[c][j] counts how often choice c ended up at position j
  std::map<char, std::array<unsigned, 5>> frequency {
    {'a', {0, 0, 0, 0, 0}},
    {'b', {0, 0, 0, 0, 0}},
    {'c', {0, 0, 0, 0, 0}},
    {'d', {0, 0, 0, 0, 0}},
    {'e', {0, 0, 0, 0, 0}}
  };
  const int samples = 10000;
  std::random_device rd;
  for (auto i = 0; i < samples; i++) {
    // choices and weights are permuted in lockstep, so weights[k] remains
    // the weight of choices[k] across iterations
    weighted_shuffle(begin(choices), end(choices),
                     begin(weights), end(weights),
                     std::mt19937{rd()});
    for (size_t j = 0; j < choices.size(); ++j)
      ++frequency[choices[j]][j];
  }
  // verify that the probability that the nth choice is selected as the first
  // one is the nth weight divided by the sum of all weights
  const auto total_weight = std::accumulate(weights.begin(), weights.end(), 0);
  // The observed frequency has a standard deviation of at most
  // sqrt(0.25 / samples) = 0.005, so this tolerance is about four standard
  // deviations; a tighter bound makes the randomly seeded test fail
  // spuriously.
  constexpr float epsilon = 0.02;
  for (unsigned i = 0; i < choices.size(); i++) {
    const auto& f = frequency[choices[i]];
    const auto& w = weights[i];
    ASSERT_NEAR(float(w) / total_weight,
                float(f.front()) / samples,
                epsilon);
  }
}