From 963a68163e922ac6731bddd1958861056b066d1e Mon Sep 17 00:00:00 2001 From: =?utf8?q?Rados=C5=82aw=20Zarzy=C5=84ski?= Date: Wed, 30 Aug 2023 15:23:34 +0200 Subject: [PATCH] test/test_weighted_shuffle: verify weights containing zeros MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit Signed-off-by: Radosław Zarzyński (cherry picked from commit d02b17ff84c61123ed27d79dc177c2cfbbe6a72f) --- src/test/test_weighted_shuffle.cc | 52 +++++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/src/test/test_weighted_shuffle.cc b/src/test/test_weighted_shuffle.cc index 9f92cbdc09519..efc1cdeb7cb18 100644 --- a/src/test/test_weighted_shuffle.cc +++ b/src/test/test_weighted_shuffle.cc @@ -37,3 +37,55 @@ TEST(WeightedShuffle, Basic) { epsilon); } } + +TEST(WeightedShuffle, ZeroedWeights) { + std::array choices{'a', 'b', 'c', 'd', 'e'}; + std::array weights{0, 0, 0, 0, 0}; + std::map> frequency { + {'a', {0, 0, 0, 0, 0}}, + {'b', {0, 0, 0, 0, 0}}, + {'c', {0, 0, 0, 0, 0}}, + {'d', {0, 0, 0, 0, 0}}, + {'e', {0, 0, 0, 0, 0}} + }; // count each element appearing in each position + const int samples = 10000; + std::random_device rd; + for (auto i = 0; i < samples; i++) { + weighted_shuffle(begin(choices), end(choices), + begin(weights), end(weights), + std::mt19937{rd()}); + for (size_t j = 0; j < choices.size(); ++j) + ++frequency[choices[j]][j]; + } + + for (char ch : choices) { + // all samples on the diagonal + ASSERT_EQ(std::accumulate(begin(frequency[ch]), end(frequency[ch]), 0), + samples); + ASSERT_EQ(frequency[ch][ch-'a'], samples); + } +} + +TEST(WeightedShuffle, SingleNonZeroWeight) { + std::array choices{'a', 'b', 'c', 'd', 'e'}; + std::array weights{0, 42, 0, 0, 0}; + std::map> frequency { + {'a', {0, 0, 0, 0, 0}}, + {'b', {0, 0, 0, 0, 0}}, + {'c', {0, 0, 0, 0, 0}}, + {'d', {0, 0, 0, 0, 0}}, + {'e', {0, 0, 0, 0, 0}} + }; // count each element appearing in each position + const int samples = 10000; + std::random_device rd; + for (auto i = 0; i < samples; i++) { + weighted_shuffle(begin(choices), end(choices), + begin(weights), end(weights), + std::mt19937{rd()}); + for (size_t j = 0; j < choices.size(); ++j) + ++frequency[choices[j]][j]; + } + + // 'b' is always first + ASSERT_EQ(frequency['b'][0], samples); +} -- 2.39.5