fast_uniform_bits.h 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268
  1. // Copyright 2017 The Abseil Authors.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // https://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #ifndef ABSL_RANDOM_INTERNAL_FAST_UNIFORM_BITS_H_
  15. #define ABSL_RANDOM_INTERNAL_FAST_UNIFORM_BITS_H_
  16. #include <cstddef>
  17. #include <cstdint>
  18. #include <limits>
  19. #include <type_traits>
  20. #include "absl/base/config.h"
  21. #include "absl/meta/type_traits.h"
  22. namespace absl {
  23. ABSL_NAMESPACE_BEGIN
  24. namespace random_internal {
  // Reports whether `n` is either zero or an exact power of two.
  //
  // FastUniformBits applies this to the size of a URBG's output range to decide
  // whether every value drawn from the URBG can be used directly (power-of-two
  // range) or whether some draws have to be rejected.
  template <typename UIntType>
  constexpr bool IsPowerOfTwoOrZero(UIntType n) {
    // A power of two has exactly one bit set, so clearing its lowest set bit
    // (`n & (n - 1)`) leaves nothing behind. Zero is accepted explicitly.
    return n == 0 ? true : (n & (n - 1)) == 0;
  }
  // Returns the number of distinct values the URBG can produce, i.e.
  // `URBG::max() - URBG::min() + 1`.
  //
  // When the URBG spans every value representable in `URBG::result_type`, that
  // count does not fit in `result_type`; zero is returned in that case instead.
  template <typename URBG>
  constexpr typename URBG::result_type RangeSize() {
    using result_type = typename URBG::result_type;
    using limits = std::numeric_limits<result_type>;
    static_assert((URBG::max)() != (URBG::min)(), "URBG range cannot be 0.");
    return ((URBG::min)() != limits::lowest() ||
            (URBG::max)() != (limits::max)())
               ? static_cast<result_type>((URBG::max)() - (URBG::min)() +
                                          result_type{1})
               : result_type{0};
  }
  // Computes the floor of the base-2 logarithm of `n`, i.e.
  // `std::floor(std::log2(n))`, using only integer arithmetic.
  //
  // Inputs of 0 and 1 both yield 0.
  template <typename UIntType>
  constexpr UIntType IntegerLog2(UIntType n) {
    // Each halving that still leaves a value above 1 contributes one to the
    // result.
    return (n > 1) ? 1 + IntegerLog2(n >> 1) : 0;
  }
  48. // Returns the number of bits of randomness returned through
  49. // `PowerOfTwoVariate(urbg)`.
  50. template <typename URBG>
  51. constexpr size_t NumBits() {
  52. return RangeSize<URBG>() == 0
  53. ? std::numeric_limits<typename URBG::result_type>::digits
  54. : IntegerLog2(RangeSize<URBG>());
  55. }
  // Given a shift value `n`, constructs a mask with exactly the low `n` bits
  // set.
  //
  // Special cases:
  //   * `n == 0` yields a mask with all bits set. Callers reduce a bit count
  //     modulo the width of the result type, so 0 stands for "the full width".
  //   * `n >= std::numeric_limits<UIntType>::digits` also yields all bits set.
  //     Shifting by the width of the type or more is undefined behaviour, so
  //     such requests are saturated rather than shifted.
  template <typename UIntType>
  constexpr UIntType MaskFromShift(size_t n) {
    return (n == 0 ||
            n >= static_cast<size_t>(std::numeric_limits<UIntType>::digits))
               ? static_cast<UIntType>(~UIntType{0})
               : static_cast<UIntType>((UIntType{1} << n) - UIntType{1});
  }
  // Empty tag types used for overload selection in FastUniformBits::Generate.

  // Selects the overload that uses every URBG draw as-is. It is chosen when
  // the size of the URBG's output range is a power of two (or spans the whole
  // result type).
  struct SimplifiedLoopTag {};

  // Selects the overload that discards out-of-range URBG draws. It is chosen
  // when the size of the URBG's output range is not a power of two.
  struct RejectionLoopTag {};
  68. // FastUniformBits implements a fast path to acquire uniform independent bits
  69. // from a type which conforms to the [rand.req.urbg] concept.
  70. // Parameterized by:
  71. // `UIntType`: the result (output) type
  72. //
  73. // The std::independent_bits_engine [rand.adapt.ibits] adaptor can be
  74. // instantiated from an existing generator through a copy or a move. It does
  75. // not, however, facilitate the production of pseudorandom bits from an un-owned
  76. // generator that will outlive the std::independent_bits_engine instance.
  77. template <typename UIntType = uint64_t>
  78. class FastUniformBits {
  79. public:
  80. using result_type = UIntType;
  81. static constexpr result_type(min)() { return 0; }
  82. static constexpr result_type(max)() {
  83. return (std::numeric_limits<result_type>::max)();
  84. }
  85. template <typename URBG>
  86. result_type operator()(URBG& g); // NOLINT(runtime/references)
  87. private:
  88. static_assert(std::is_unsigned<UIntType>::value,
  89. "Class-template FastUniformBits<> must be parameterized using "
  90. "an unsigned type.");
  91. // Generate() generates a random value, dispatched on whether
  92. // the underlying URBG must use rejection sampling to generate a value,
  93. // or whether a simplified loop will suffice.
  94. template <typename URBG>
  95. result_type Generate(URBG& g, // NOLINT(runtime/references)
  96. SimplifiedLoopTag);
  97. template <typename URBG>
  98. result_type Generate(URBG& g, // NOLINT(runtime/references)
  99. RejectionLoopTag);
  100. };
  101. template <typename UIntType>
  102. template <typename URBG>
  103. typename FastUniformBits<UIntType>::result_type
  104. FastUniformBits<UIntType>::operator()(URBG& g) { // NOLINT(runtime/references)
  105. // kRangeMask is the mask used when sampling variates from the URBG when the
  106. // width of the URBG range is not a power of 2.
  107. // Y = (2 ^ kRange) - 1
  108. static_assert((URBG::max)() > (URBG::min)(),
  109. "URBG::max and URBG::min may not be equal.");
  110. using tag = absl::conditional_t<IsPowerOfTwoOrZero(RangeSize<URBG>()),
  111. SimplifiedLoopTag, RejectionLoopTag>;
  112. return Generate(g, tag{});
  113. }
  114. template <typename UIntType>
  115. template <typename URBG>
  116. typename FastUniformBits<UIntType>::result_type
  117. FastUniformBits<UIntType>::Generate(URBG& g, // NOLINT(runtime/references)
  118. SimplifiedLoopTag) {
  119. // The simplified version of FastUniformBits works only on URBGs that have
  120. // a range that is a power of 2. In this case we simply loop and shift without
  121. // attempting to balance the bits across calls.
  122. static_assert(IsPowerOfTwoOrZero(RangeSize<URBG>()),
  123. "incorrect Generate tag for URBG instance");
  124. static constexpr size_t kResultBits =
  125. std::numeric_limits<result_type>::digits;
  126. static constexpr size_t kUrbgBits = NumBits<URBG>();
  127. static constexpr size_t kIters =
  128. (kResultBits / kUrbgBits) + (kResultBits % kUrbgBits != 0);
  129. static constexpr size_t kShift = (kIters == 1) ? 0 : kUrbgBits;
  130. static constexpr auto kMin = (URBG::min)();
  131. result_type r = static_cast<result_type>(g() - kMin);
  132. for (size_t n = 1; n < kIters; ++n) {
  133. r = (r << kShift) + static_cast<result_type>(g() - kMin);
  134. }
  135. return r;
  136. }
  137. template <typename UIntType>
  138. template <typename URBG>
  139. typename FastUniformBits<UIntType>::result_type
  140. FastUniformBits<UIntType>::Generate(URBG& g, // NOLINT(runtime/references)
  141. RejectionLoopTag) {
  142. static_assert(!IsPowerOfTwoOrZero(RangeSize<URBG>()),
  143. "incorrect Generate tag for URBG instance");
  144. using urbg_result_type = typename URBG::result_type;
  145. // See [rand.adapt.ibits] for more details on the constants calculated below.
  146. //
  147. // It is preferable to use roughly the same number of bits from each generator
  148. // call, however this is only possible when the number of bits provided by the
  149. // URBG is a divisor of the number of bits in `result_type`. In all other
  150. // cases, the number of bits used cannot always be the same, but it can be
  151. // guaranteed to be off by at most 1. Thus we run two loops, one with a
  152. // smaller bit-width size (`kSmallWidth`) and one with a larger width size
  153. // (satisfying `kLargeWidth == kSmallWidth + 1`). The loops are run
  154. // `kSmallIters` and `kLargeIters` times respectively such
  155. // that
  156. //
  157. // `kResultBits == kSmallIters * kSmallBits
  158. // + kLargeIters * kLargeBits`
  159. //
  160. // where `kResultBits` is the total number of bits in `result_type`.
  161. //
  162. static constexpr size_t kResultBits =
  163. std::numeric_limits<result_type>::digits; // w
  164. static constexpr urbg_result_type kUrbgRange = RangeSize<URBG>(); // R
  165. static constexpr size_t kUrbgBits = NumBits<URBG>(); // m
  166. // compute the initial estimate of the bits used.
  167. // [rand.adapt.ibits] 2 (c)
  168. static constexpr size_t kA = // ceil(w/m)
  169. (kResultBits / kUrbgBits) + ((kResultBits % kUrbgBits) != 0); // n'
  170. static constexpr size_t kABits = kResultBits / kA; // w0'
  171. static constexpr urbg_result_type kARejection =
  172. ((kUrbgRange >> kABits) << kABits); // y0'
  173. // refine the selection to reduce the rejection frequency.
  174. static constexpr size_t kTotalIters =
  175. ((kUrbgRange - kARejection) <= (kARejection / kA)) ? kA : (kA + 1); // n
  176. // [rand.adapt.ibits] 2 (b)
  177. static constexpr size_t kSmallIters =
  178. kTotalIters - (kResultBits % kTotalIters); // n0
  179. static constexpr size_t kSmallBits = kResultBits / kTotalIters; // w0
  180. static constexpr urbg_result_type kSmallRejection =
  181. ((kUrbgRange >> kSmallBits) << kSmallBits); // y0
  182. static constexpr size_t kLargeBits = kSmallBits + 1; // w0+1
  183. static constexpr urbg_result_type kLargeRejection =
  184. ((kUrbgRange >> kLargeBits) << kLargeBits); // y1
  185. //
  186. // Because `kLargeBits == kSmallBits + 1`, it follows that
  187. //
  188. // `kResultBits == kSmallIters * kSmallBits + kLargeIters`
  189. //
  190. // and therefore
  191. //
  192. // `kLargeIters == kTotalWidth % kSmallWidth`
  193. //
  194. // Intuitively, each iteration with the large width accounts for one unit
  195. // of the remainder when `kTotalWidth` is divided by `kSmallWidth`. As
  196. // mentioned above, if the URBG width is a divisor of `kTotalWidth`, then
  197. // there would be no need for any large iterations (i.e., one loop would
  198. // suffice), and indeed, in this case, `kLargeIters` would be zero.
  199. static_assert(kResultBits == kSmallIters * kSmallBits +
  200. (kTotalIters - kSmallIters) * kLargeBits,
  201. "Error in looping constant calculations.");
  202. // The small shift is essentially small bits, but due to the potential
  203. // of generating a smaller result_type from a larger urbg type, the actual
  204. // shift might be 0.
  205. static constexpr size_t kSmallShift = kSmallBits % kResultBits;
  206. static constexpr auto kSmallMask =
  207. MaskFromShift<urbg_result_type>(kSmallShift);
  208. static constexpr size_t kLargeShift = kLargeBits % kResultBits;
  209. static constexpr auto kLargeMask =
  210. MaskFromShift<urbg_result_type>(kLargeShift);
  211. static constexpr auto kMin = (URBG::min)();
  212. result_type s = 0;
  213. for (size_t n = 0; n < kSmallIters; ++n) {
  214. urbg_result_type v;
  215. do {
  216. v = g() - kMin;
  217. } while (v >= kSmallRejection);
  218. s = (s << kSmallShift) + static_cast<result_type>(v & kSmallMask);
  219. }
  220. for (size_t n = kSmallIters; n < kTotalIters; ++n) {
  221. urbg_result_type v;
  222. do {
  223. v = g() - kMin;
  224. } while (v >= kLargeRejection);
  225. s = (s << kLargeShift) + static_cast<result_type>(v & kLargeMask);
  226. }
  227. return s;
  228. }
  229. } // namespace random_internal
  230. ABSL_NAMESPACE_END
  231. } // namespace absl
  232. #endif // ABSL_RANDOM_INTERNAL_FAST_UNIFORM_BITS_H_