generators_test.cc 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
  1. // Copyright 2017 The Abseil Authors.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // https://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include <cstddef>
  15. #include <cstdint>
  16. #include <random>
  17. #include <vector>
  18. #include "gtest/gtest.h"
  19. #include "absl/random/distributions.h"
  20. #include "absl/random/random.h"
  21. namespace {
  22. template <typename URBG>
  23. void TestUniform(URBG* gen) {
  24. // [a, b) default-semantics, inferred types.
  25. absl::Uniform(*gen, 0, 100); // int
  26. absl::Uniform(*gen, 0, 1.0); // Promoted to double
  27. absl::Uniform(*gen, 0.0f, 1.0); // Promoted to double
  28. absl::Uniform(*gen, 0.0, 1.0); // double
  29. absl::Uniform(*gen, -1, 1L); // Promoted to long
  30. // Roll a die.
  31. absl::Uniform(absl::IntervalClosedClosed, *gen, 1, 6);
  32. // Get a fraction.
  33. absl::Uniform(absl::IntervalOpenOpen, *gen, 0.0, 1.0);
  34. // Assign a value to a random element.
  35. std::vector<int> elems = {10, 20, 30, 40, 50};
  36. elems[absl::Uniform(*gen, 0u, elems.size())] = 5;
  37. elems[absl::Uniform<size_t>(*gen, 0, elems.size())] = 3;
  38. // Choose some epsilon around zero.
  39. absl::Uniform(absl::IntervalOpenOpen, *gen, -1.0, 1.0);
  40. // (a, b) semantics, inferred types.
  41. absl::Uniform(absl::IntervalOpenOpen, *gen, 0, 1.0); // Promoted to double
  42. // Explict overriding of types.
  43. absl::Uniform<int>(*gen, 0, 100);
  44. absl::Uniform<int8_t>(*gen, 0, 100);
  45. absl::Uniform<int16_t>(*gen, 0, 100);
  46. absl::Uniform<uint16_t>(*gen, 0, 100);
  47. absl::Uniform<int32_t>(*gen, 0, 1 << 10);
  48. absl::Uniform<uint32_t>(*gen, 0, 1 << 10);
  49. absl::Uniform<int64_t>(*gen, 0, 1 << 10);
  50. absl::Uniform<uint64_t>(*gen, 0, 1 << 10);
  51. absl::Uniform<float>(*gen, 0.0, 1.0);
  52. absl::Uniform<float>(*gen, 0, 1);
  53. absl::Uniform<float>(*gen, -1, 1);
  54. absl::Uniform<double>(*gen, 0.0, 1.0);
  55. absl::Uniform<float>(*gen, -1.0, 0);
  56. absl::Uniform<double>(*gen, -1.0, 0);
  57. // Tagged
  58. absl::Uniform<double>(absl::IntervalClosedClosed, *gen, 0, 1);
  59. absl::Uniform<double>(absl::IntervalClosedOpen, *gen, 0, 1);
  60. absl::Uniform<double>(absl::IntervalOpenOpen, *gen, 0, 1);
  61. absl::Uniform<double>(absl::IntervalOpenClosed, *gen, 0, 1);
  62. absl::Uniform<double>(absl::IntervalClosedClosed, *gen, 0, 1);
  63. absl::Uniform<double>(absl::IntervalOpenOpen, *gen, 0, 1);
  64. absl::Uniform<int>(absl::IntervalClosedClosed, *gen, 0, 100);
  65. absl::Uniform<int>(absl::IntervalClosedOpen, *gen, 0, 100);
  66. absl::Uniform<int>(absl::IntervalOpenOpen, *gen, 0, 100);
  67. absl::Uniform<int>(absl::IntervalOpenClosed, *gen, 0, 100);
  68. absl::Uniform<int>(absl::IntervalClosedClosed, *gen, 0, 100);
  69. absl::Uniform<int>(absl::IntervalOpenOpen, *gen, 0, 100);
  70. // With *generator as an R-value reference.
  71. absl::Uniform<int>(URBG(), 0, 100);
  72. absl::Uniform<double>(URBG(), 0.0, 1.0);
  73. }
  74. template <typename URBG>
  75. void TestExponential(URBG* gen) {
  76. absl::Exponential<float>(*gen);
  77. absl::Exponential<double>(*gen);
  78. absl::Exponential<double>(URBG());
  79. }
  80. template <typename URBG>
  81. void TestPoisson(URBG* gen) {
  82. // [rand.dist.pois] Indicates that the std::poisson_distribution
  83. // is parameterized by IntType, however MSVC does not allow 8-bit
  84. // types.
  85. absl::Poisson<int>(*gen);
  86. absl::Poisson<int16_t>(*gen);
  87. absl::Poisson<uint16_t>(*gen);
  88. absl::Poisson<int32_t>(*gen);
  89. absl::Poisson<uint32_t>(*gen);
  90. absl::Poisson<int64_t>(*gen);
  91. absl::Poisson<uint64_t>(*gen);
  92. absl::Poisson<uint64_t>(URBG());
  93. }
  94. template <typename URBG>
  95. void TestBernoulli(URBG* gen) {
  96. absl::Bernoulli(*gen, 0.5);
  97. absl::Bernoulli(*gen, 0.5);
  98. }
  99. template <typename URBG>
  100. void TestZipf(URBG* gen) {
  101. absl::Zipf<int>(*gen, 100);
  102. absl::Zipf<int8_t>(*gen, 100);
  103. absl::Zipf<int16_t>(*gen, 100);
  104. absl::Zipf<uint16_t>(*gen, 100);
  105. absl::Zipf<int32_t>(*gen, 1 << 10);
  106. absl::Zipf<uint32_t>(*gen, 1 << 10);
  107. absl::Zipf<int64_t>(*gen, 1 << 10);
  108. absl::Zipf<uint64_t>(*gen, 1 << 10);
  109. absl::Zipf<uint64_t>(URBG(), 1 << 10);
  110. }
  111. template <typename URBG>
  112. void TestGaussian(URBG* gen) {
  113. absl::Gaussian<float>(*gen, 1.0, 1.0);
  114. absl::Gaussian<double>(*gen, 1.0, 1.0);
  115. absl::Gaussian<double>(URBG(), 1.0, 1.0);
  116. }
  117. template <typename URBG>
  118. void TestLogNormal(URBG* gen) {
  119. absl::LogUniform<int>(*gen, 0, 100);
  120. absl::LogUniform<int8_t>(*gen, 0, 100);
  121. absl::LogUniform<int16_t>(*gen, 0, 100);
  122. absl::LogUniform<uint16_t>(*gen, 0, 100);
  123. absl::LogUniform<int32_t>(*gen, 0, 1 << 10);
  124. absl::LogUniform<uint32_t>(*gen, 0, 1 << 10);
  125. absl::LogUniform<int64_t>(*gen, 0, 1 << 10);
  126. absl::LogUniform<uint64_t>(*gen, 0, 1 << 10);
  127. absl::LogUniform<uint64_t>(URBG(), 0, 1 << 10);
  128. }
  // Runs every distribution smoke test in this file against generator type
  // `URBG`. A single default-constructed engine is shared by all of them, in
  // order: Uniform, Exponential, Poisson, Bernoulli, Zipf, Gaussian,
  // LogUniform.
  template <typename URBG>
  void CompatibilityTest() {
    URBG engine;
    URBG* const engine_ptr = &engine;

    TestUniform(engine_ptr);
    TestExponential(engine_ptr);
    TestPoisson(engine_ptr);
    TestBernoulli(engine_ptr);
    TestZipf(engine_ptr);
    TestGaussian(engine_ptr);
    TestLogNormal(engine_ptr);
  }
  140. TEST(std_mt19937_64, Compatibility) {
  141. // Validate with std::mt19937_64
  142. CompatibilityTest<std::mt19937_64>();
  143. }
  144. TEST(BitGen, Compatibility) {
  145. // Validate with absl::BitGen
  146. CompatibilityTest<absl::BitGen>();
  147. }
  148. TEST(InsecureBitGen, Compatibility) {
  149. // Validate with absl::InsecureBitGen
  150. CompatibilityTest<absl::InsecureBitGen>();
  151. }
  152. } // namespace