]> git.apps.os.sepia.ceph.com Git - ceph-ci.git/commitdiff
common/weighted_shufft: rewrite it to avoid copyright issue
authorKefu Chai <kchai@redhat.com>
Fri, 22 Mar 2019 13:28:26 +0000 (21:28 +0800)
committerKefu Chai <kchai@redhat.com>
Fri, 22 Mar 2019 14:10:26 +0000 (22:10 +0800)
Signed-off-by: Kefu Chai <kchai@redhat.com>
src/common/weighted_shuffle.h
src/mon/MonClient.cc
src/test/test_weighted_shuffle.cc

index 98c00c12112417427d31c6ae2cd4a7c725ed017f..10def0a011a41ecc8a6dd19ca3b88834e72ac2e2 100644 (file)
@@ -1,29 +1,25 @@
 // -*- mode:C++; tab-width:8; c-basic-offset:2; indent-tabs-mode:t -*-
 // vim: ts=8 sw=2 smarttab
-// https://stackoverflow.com/questions/50221136/c-weighted-stdshuffle/50223540
-#ifndef CEPH_WEIGHTED_SHUFFLE_H
-#define CEPH_WEIGHTED_SHUFFLE_H
+
+#pragma once
 
 #include <algorithm>
+#include <iterator>
 #include <random>
-#include <vector>
 
-template <class T, class U, class R>
-void weighted_shuffle(std::vector<T> &data, std::vector<U> &weights, R &&gen)
+template <class RandomIt, class DistIt, class URBG>
+void weighted_shuffle(RandomIt first, RandomIt last,
+                     DistIt weight_first, DistIt weight_last,
+                     URBG &&g)
 {
-  auto itd = data.begin();
-  auto itw = weights.begin();
-
-  while (itd != data.end() && itw != weights.end()) {
-    std::discrete_distribution d(itw, weights.end());
-    auto i = d(gen);
-    if (i) {
-      std::iter_swap(itd, std::next(itd, i));
-      std::iter_swap(itw, std::next(itw, i));
+  if (first == last) {
+    return;
+  } else {
+    std::discrete_distribution d{weight_first, weight_last};
+    if (auto n = d(g); n > 0) {
+      std::iter_swap(first, std::next(first, n));
+      std::iter_swap(weight_first, std::next(weight_first, n));
     }
-    ++itd;
-    ++itw;
+    weighted_shuffle(++first, last, ++weight_first, weight_last, std::move(g));
   }
 }
-
-#endif
index ef81f5356e89f7599c7fe7cbd672d6de0e864781..3c361ef59e263e60aaf44cd88fbea8fed9b97e2a 100644 (file)
@@ -730,8 +730,8 @@ void MonClient::_add_conns(uint64_t global_id)
       weights.push_back(monmap.get_weight(rank_name));
     }
     std::random_device rd;
-    std::mt19937 gen(rd());
-    weighted_shuffle(ranks, weights, gen);
+    weighted_shuffle(begin(ranks), end(ranks), begin(weights), end(weights),
+                    std::mt19937{rd()});
   }
   ldout(cct, 10) << __func__ << " ranks=" << ranks << dendl;
   unsigned n = cct->_conf->mon_client_hunt_parallel;
index 47d1438926cc220a13073fda8ffd9d7e8fdf8f74..7e881134a7f783a9688c886837371136c4d45d14 100644 (file)
@@ -1,20 +1,14 @@
 // -*- mode:C++; tab-width:8; c-basic-offset:2; indent-tabs-mode:t -*-
 // vim: ts=8 sw=2 smarttab
-// https://stackoverflow.com/questions/50221136/c-weighted-stdshuffle/50223540
 
 #include "common/weighted_shuffle.h"
 #include <map>
 #include "gtest/gtest.h"
 
 TEST(WeightedShuffle, Basic) {
-  std::random_device rd;
-  std::mt19937 gen(rd());
-
-  std::vector<char> data{'a', 'b', 'c', 'd', 'e'};
-  // NB: differences between weights should be significant
-  // otherwise test might fail
-  std::vector<int> weights{100, 50, 25, 10, 1};
-  std::map<char, std::vector<int>> frequency {
+  std::array<char, 5> choices{'a', 'b', 'c', 'd', 'e'};
+  std::array<int, 5> weights{100, 50, 25, 10, 1};
+  std::map<char, std::array<unsigned, 5>> frequency {
     {'a', {0, 0, 0, 0, 0}},
     {'b', {0, 0, 0, 0, 0}},
     {'c', {0, 0, 0, 0, 0}},
@@ -22,22 +16,23 @@ TEST(WeightedShuffle, Basic) {
     {'e', {0, 0, 0, 0, 0}}
   }; // count each element appearing in each position
   const int samples = 10000;
+  std::random_device rd;
   for (auto i = 0; i < samples; i++) {
-    weighted_shuffle(data, weights, gen);
-    for (size_t j = 0; j < data.size(); ++j)
-      ++frequency[data[j]][j];
+    weighted_shuffle(begin(choices), end(choices),
+                    begin(weights), end(weights),
+                    std::mt19937{rd()});
+    for (size_t j = 0; j < choices.size(); ++j)
+      ++frequency[choices[j]][j];
+  }
+  // verify that the probability that the nth choice is selected as the first
+  // one is the nth weight divided by the sum of all weights
+  const auto total_weight = std::accumulate(weights.begin(), weights.end(), 0);
+  constexpr float epsilon = 0.01;
+  for (unsigned i = 0; i < choices.size(); i++) {
+    const auto& f = frequency[choices[i]];
+    const auto& w = weights[i];
+    ASSERT_NEAR(float(w) / total_weight,
+               float(f.front()) / samples,
+               epsilon);
   }
-
-  // verify each element gets a chance to stay at the header of array
-  // e.g., because we don't want to starve anybody!
-  EXPECT_TRUE(std::all_of(frequency.begin(), frequency.end(), [](auto& it) {
-      return it.second.front() > 0;
-    }));
-
-  // verify the probability (of staying at the array header)
-  // are produced according to their corresponding weight
-  EXPECT_TRUE(std::is_sorted(frequency.begin(), frequency.end(),
-    [](auto& lhs, auto& rhs) {
-      return lhs.second.front() > rhs.second.front();
-    }));
 }