// uniform_helper_test.cc
// Copyright 2017 The Abseil Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
  14. #include "absl/random/internal/uniform_helper.h"
  15. #include <cmath>
  16. #include <cstdint>
  17. #include <random>
  18. #include "gtest/gtest.h"
  19. namespace {
  20. using absl::IntervalClosedClosedTag;
  21. using absl::IntervalClosedOpenTag;
  22. using absl::IntervalOpenClosedTag;
  23. using absl::IntervalOpenOpenTag;
  24. using absl::random_internal::uniform_inferred_return_t;
  25. using absl::random_internal::uniform_lower_bound;
  26. using absl::random_internal::uniform_upper_bound;
  27. class UniformHelperTest : public testing::Test {};
  28. TEST_F(UniformHelperTest, UniformBoundFunctionsGeneral) {
  29. constexpr IntervalClosedClosedTag IntervalClosedClosed;
  30. constexpr IntervalClosedOpenTag IntervalClosedOpen;
  31. constexpr IntervalOpenClosedTag IntervalOpenClosed;
  32. constexpr IntervalOpenOpenTag IntervalOpenOpen;
  33. // absl::uniform_int_distribution natively assumes IntervalClosedClosed
  34. // absl::uniform_real_distribution natively assumes IntervalClosedOpen
  35. EXPECT_EQ(uniform_lower_bound(IntervalOpenClosed, 0, 100), 1);
  36. EXPECT_EQ(uniform_lower_bound(IntervalOpenOpen, 0, 100), 1);
  37. EXPECT_GT(uniform_lower_bound<float>(IntervalOpenClosed, 0, 1.0), 0);
  38. EXPECT_GT(uniform_lower_bound<float>(IntervalOpenOpen, 0, 1.0), 0);
  39. EXPECT_GT(uniform_lower_bound<double>(IntervalOpenClosed, 0, 1.0), 0);
  40. EXPECT_GT(uniform_lower_bound<double>(IntervalOpenOpen, 0, 1.0), 0);
  41. EXPECT_EQ(uniform_lower_bound(IntervalClosedClosed, 0, 100), 0);
  42. EXPECT_EQ(uniform_lower_bound(IntervalClosedOpen, 0, 100), 0);
  43. EXPECT_EQ(uniform_lower_bound<float>(IntervalClosedClosed, 0, 1.0), 0);
  44. EXPECT_EQ(uniform_lower_bound<float>(IntervalClosedOpen, 0, 1.0), 0);
  45. EXPECT_EQ(uniform_lower_bound<double>(IntervalClosedClosed, 0, 1.0), 0);
  46. EXPECT_EQ(uniform_lower_bound<double>(IntervalClosedOpen, 0, 1.0), 0);
  47. EXPECT_EQ(uniform_upper_bound(IntervalOpenOpen, 0, 100), 99);
  48. EXPECT_EQ(uniform_upper_bound(IntervalClosedOpen, 0, 100), 99);
  49. EXPECT_EQ(uniform_upper_bound<float>(IntervalOpenOpen, 0, 1.0), 1.0);
  50. EXPECT_EQ(uniform_upper_bound<float>(IntervalClosedOpen, 0, 1.0), 1.0);
  51. EXPECT_EQ(uniform_upper_bound<double>(IntervalOpenOpen, 0, 1.0), 1.0);
  52. EXPECT_EQ(uniform_upper_bound<double>(IntervalClosedOpen, 0, 1.0), 1.0);
  53. EXPECT_EQ(uniform_upper_bound(IntervalOpenClosed, 0, 100), 100);
  54. EXPECT_EQ(uniform_upper_bound(IntervalClosedClosed, 0, 100), 100);
  55. EXPECT_GT(uniform_upper_bound<float>(IntervalOpenClosed, 0, 1.0), 1.0);
  56. EXPECT_GT(uniform_upper_bound<float>(IntervalClosedClosed, 0, 1.0), 1.0);
  57. EXPECT_GT(uniform_upper_bound<double>(IntervalOpenClosed, 0, 1.0), 1.0);
  58. EXPECT_GT(uniform_upper_bound<double>(IntervalClosedClosed, 0, 1.0), 1.0);
  59. // Negative value tests
  60. EXPECT_EQ(uniform_lower_bound(IntervalOpenClosed, -100, -1), -99);
  61. EXPECT_EQ(uniform_lower_bound(IntervalOpenOpen, -100, -1), -99);
  62. EXPECT_GT(uniform_lower_bound<float>(IntervalOpenClosed, -2.0, -1.0), -2.0);
  63. EXPECT_GT(uniform_lower_bound<float>(IntervalOpenOpen, -2.0, -1.0), -2.0);
  64. EXPECT_GT(uniform_lower_bound<double>(IntervalOpenClosed, -2.0, -1.0), -2.0);
  65. EXPECT_GT(uniform_lower_bound<double>(IntervalOpenOpen, -2.0, -1.0), -2.0);
  66. EXPECT_EQ(uniform_lower_bound(IntervalClosedClosed, -100, -1), -100);
  67. EXPECT_EQ(uniform_lower_bound(IntervalClosedOpen, -100, -1), -100);
  68. EXPECT_EQ(uniform_lower_bound<float>(IntervalClosedClosed, -2.0, -1.0), -2.0);
  69. EXPECT_EQ(uniform_lower_bound<float>(IntervalClosedOpen, -2.0, -1.0), -2.0);
  70. EXPECT_EQ(uniform_lower_bound<double>(IntervalClosedClosed, -2.0, -1.0),
  71. -2.0);
  72. EXPECT_EQ(uniform_lower_bound<double>(IntervalClosedOpen, -2.0, -1.0), -2.0);
  73. EXPECT_EQ(uniform_upper_bound(IntervalOpenOpen, -100, -1), -2);
  74. EXPECT_EQ(uniform_upper_bound(IntervalClosedOpen, -100, -1), -2);
  75. EXPECT_EQ(uniform_upper_bound<float>(IntervalOpenOpen, -2.0, -1.0), -1.0);
  76. EXPECT_EQ(uniform_upper_bound<float>(IntervalClosedOpen, -2.0, -1.0), -1.0);
  77. EXPECT_EQ(uniform_upper_bound<double>(IntervalOpenOpen, -2.0, -1.0), -1.0);
  78. EXPECT_EQ(uniform_upper_bound<double>(IntervalClosedOpen, -2.0, -1.0), -1.0);
  79. EXPECT_EQ(uniform_upper_bound(IntervalOpenClosed, -100, -1), -1);
  80. EXPECT_EQ(uniform_upper_bound(IntervalClosedClosed, -100, -1), -1);
  81. EXPECT_GT(uniform_upper_bound<float>(IntervalOpenClosed, -2.0, -1.0), -1.0);
  82. EXPECT_GT(uniform_upper_bound<float>(IntervalClosedClosed, -2.0, -1.0), -1.0);
  83. EXPECT_GT(uniform_upper_bound<double>(IntervalOpenClosed, -2.0, -1.0), -1.0);
  84. EXPECT_GT(uniform_upper_bound<double>(IntervalClosedClosed, -2.0, -1.0),
  85. -1.0);
  86. EXPECT_GT(uniform_lower_bound(IntervalOpenClosed, 1.0, 2.0), 1.0);
  87. EXPECT_LT(uniform_lower_bound(IntervalOpenClosed, 1.0, +0.0), 1.0);
  88. EXPECT_LT(uniform_lower_bound(IntervalOpenClosed, 1.0, -0.0), 1.0);
  89. EXPECT_LT(uniform_lower_bound(IntervalOpenClosed, 1.0, -1.0), 1.0);
  90. }
  91. TEST_F(UniformHelperTest, UniformBoundFunctionsIntBounds) {
  92. // Verifies the saturating nature of uniform_lower_bound and
  93. // uniform_upper_bound
  94. constexpr IntervalOpenOpenTag IntervalOpenOpen;
  95. // uint max.
  96. constexpr auto m = (std::numeric_limits<uint64_t>::max)();
  97. EXPECT_EQ(1, uniform_lower_bound(IntervalOpenOpen, 0u, 0u));
  98. EXPECT_EQ(m, uniform_lower_bound(IntervalOpenOpen, m, m));
  99. EXPECT_EQ(m, uniform_lower_bound(IntervalOpenOpen, m - 1, m - 1));
  100. EXPECT_EQ(0, uniform_upper_bound(IntervalOpenOpen, 0u, 0u));
  101. EXPECT_EQ(m - 1, uniform_upper_bound(IntervalOpenOpen, m, m));
  102. // int min/max
  103. constexpr auto l = (std::numeric_limits<int64_t>::min)();
  104. constexpr auto r = (std::numeric_limits<int64_t>::max)();
  105. EXPECT_EQ(1, uniform_lower_bound(IntervalOpenOpen, 0, 0));
  106. EXPECT_EQ(l + 1, uniform_lower_bound(IntervalOpenOpen, l, l));
  107. EXPECT_EQ(r, uniform_lower_bound(IntervalOpenOpen, r - 1, r - 1));
  108. EXPECT_EQ(r, uniform_lower_bound(IntervalOpenOpen, r, r));
  109. EXPECT_EQ(-1, uniform_upper_bound(IntervalOpenOpen, 0, 0));
  110. EXPECT_EQ(l, uniform_upper_bound(IntervalOpenOpen, l, l));
  111. EXPECT_EQ(r - 1, uniform_upper_bound(IntervalOpenOpen, r, r));
  112. }
  113. TEST_F(UniformHelperTest, UniformBoundFunctionsRealBounds) {
  114. // absl::uniform_real_distribution natively assumes IntervalClosedOpen;
  115. // use the inverse here so each bound has to change.
  116. constexpr IntervalOpenClosedTag IntervalOpenClosed;
  117. // Edge cases: the next value toward itself is itself.
  118. EXPECT_EQ(1.0, uniform_lower_bound(IntervalOpenClosed, 1.0, 1.0));
  119. EXPECT_EQ(1.0f, uniform_lower_bound(IntervalOpenClosed, 1.0f, 1.0f));
  120. // rightmost and leftmost finite values.
  121. constexpr auto r = (std::numeric_limits<double>::max)();
  122. const auto re = std::nexttoward(r, 0.0);
  123. constexpr auto l = -r;
  124. const auto le = std::nexttoward(l, 0.0);
  125. EXPECT_EQ(l, uniform_lower_bound(IntervalOpenClosed, l, l)); // (l,l)
  126. EXPECT_EQ(r, uniform_lower_bound(IntervalOpenClosed, r, r)); // (r,r)
  127. EXPECT_EQ(le, uniform_lower_bound(IntervalOpenClosed, l, r)); // (l,r)
  128. EXPECT_EQ(le, uniform_lower_bound(IntervalOpenClosed, l, 0.0)); // (l, 0)
  129. EXPECT_EQ(le, uniform_lower_bound(IntervalOpenClosed, l, le)); // (l, le)
  130. EXPECT_EQ(r, uniform_lower_bound(IntervalOpenClosed, re, r)); // (re, r)
  131. EXPECT_EQ(le, uniform_upper_bound(IntervalOpenClosed, l, l)); // (l,l)
  132. EXPECT_EQ(r, uniform_upper_bound(IntervalOpenClosed, r, r)); // (r,r)
  133. EXPECT_EQ(r, uniform_upper_bound(IntervalOpenClosed, l, r)); // (l,r)
  134. EXPECT_EQ(r, uniform_upper_bound(IntervalOpenClosed, l, re)); // (l,re)
  135. EXPECT_EQ(r, uniform_upper_bound(IntervalOpenClosed, 0.0, r)); // (0, r)
  136. EXPECT_EQ(r, uniform_upper_bound(IntervalOpenClosed, re, r)); // (re, r)
  137. EXPECT_EQ(r, uniform_upper_bound(IntervalOpenClosed, le, re)); // (le, re)
  138. const double e = std::nextafter(1.0, 2.0); // 1 + epsilon
  139. const double f = std::nextafter(1.0, 0.0); // 1 - epsilon
  140. // (1.0, 1.0 + epsilon)
  141. EXPECT_EQ(e, uniform_lower_bound(IntervalOpenClosed, 1.0, e));
  142. EXPECT_EQ(std::nextafter(e, 2.0),
  143. uniform_upper_bound(IntervalOpenClosed, 1.0, e));
  144. // (1.0-epsilon, 1.0)
  145. EXPECT_EQ(1.0, uniform_lower_bound(IntervalOpenClosed, f, 1.0));
  146. EXPECT_EQ(e, uniform_upper_bound(IntervalOpenClosed, f, 1.0));
  147. // denorm cases.
  148. const double g = std::numeric_limits<double>::denorm_min();
  149. const double h = std::nextafter(g, 1.0);
  150. // (0, denorm_min)
  151. EXPECT_EQ(g, uniform_lower_bound(IntervalOpenClosed, 0.0, g));
  152. EXPECT_EQ(h, uniform_upper_bound(IntervalOpenClosed, 0.0, g));
  153. // (denorm_min, 1.0)
  154. EXPECT_EQ(h, uniform_lower_bound(IntervalOpenClosed, g, 1.0));
  155. EXPECT_EQ(e, uniform_upper_bound(IntervalOpenClosed, g, 1.0));
  156. // Edge cases: invalid bounds.
  157. EXPECT_EQ(f, uniform_lower_bound(IntervalOpenClosed, 1.0, -1.0));
  158. }
  159. struct Invalid {};
  160. template <typename A, typename B>
  161. auto InferredUniformReturnT(int) -> uniform_inferred_return_t<A, B>;
  162. template <typename, typename>
  163. Invalid InferredUniformReturnT(...);
  164. // Given types <A, B, Expect>, CheckArgsInferType() verifies that
  165. //
  166. // uniform_inferred_return_t<A, B> and
  167. // uniform_inferred_return_t<B, A>
  168. //
  169. // returns the type "Expect".
  170. //
  171. // This interface can also be used to assert that a given inferred return types
  172. // are invalid. Writing:
  173. //
  174. // CheckArgsInferType<float, int, Invalid>()
  175. //
  176. // will assert that this overload does not exist.
  177. template <typename A, typename B, typename Expect>
  178. void CheckArgsInferType() {
  179. static_assert(
  180. absl::conjunction<
  181. std::is_same<Expect, decltype(InferredUniformReturnT<A, B>(0))>,
  182. std::is_same<Expect,
  183. decltype(InferredUniformReturnT<B, A>(0))>>::value,
  184. "");
  185. }
  186. TEST_F(UniformHelperTest, UniformTypeInference) {
  187. // Infers common types.
  188. CheckArgsInferType<uint16_t, uint16_t, uint16_t>();
  189. CheckArgsInferType<uint32_t, uint32_t, uint32_t>();
  190. CheckArgsInferType<uint64_t, uint64_t, uint64_t>();
  191. CheckArgsInferType<int16_t, int16_t, int16_t>();
  192. CheckArgsInferType<int32_t, int32_t, int32_t>();
  193. CheckArgsInferType<int64_t, int64_t, int64_t>();
  194. CheckArgsInferType<float, float, float>();
  195. CheckArgsInferType<double, double, double>();
  196. // Properly promotes uint16_t.
  197. CheckArgsInferType<uint16_t, uint32_t, uint32_t>();
  198. CheckArgsInferType<uint16_t, uint64_t, uint64_t>();
  199. CheckArgsInferType<uint16_t, int32_t, int32_t>();
  200. CheckArgsInferType<uint16_t, int64_t, int64_t>();
  201. CheckArgsInferType<uint16_t, float, float>();
  202. CheckArgsInferType<uint16_t, double, double>();
  203. // Properly promotes int16_t.
  204. CheckArgsInferType<int16_t, int32_t, int32_t>();
  205. CheckArgsInferType<int16_t, int64_t, int64_t>();
  206. CheckArgsInferType<int16_t, float, float>();
  207. CheckArgsInferType<int16_t, double, double>();
  208. // Invalid (u)int16_t-pairings do not compile.
  209. // See "CheckArgsInferType" comments above, for how this is achieved.
  210. CheckArgsInferType<uint16_t, int16_t, Invalid>();
  211. CheckArgsInferType<int16_t, uint32_t, Invalid>();
  212. CheckArgsInferType<int16_t, uint64_t, Invalid>();
  213. // Properly promotes uint32_t.
  214. CheckArgsInferType<uint32_t, uint64_t, uint64_t>();
  215. CheckArgsInferType<uint32_t, int64_t, int64_t>();
  216. CheckArgsInferType<uint32_t, double, double>();
  217. // Properly promotes int32_t.
  218. CheckArgsInferType<int32_t, int64_t, int64_t>();
  219. CheckArgsInferType<int32_t, double, double>();
  220. // Invalid (u)int32_t-pairings do not compile.
  221. CheckArgsInferType<uint32_t, int32_t, Invalid>();
  222. CheckArgsInferType<int32_t, uint64_t, Invalid>();
  223. CheckArgsInferType<int32_t, float, Invalid>();
  224. CheckArgsInferType<uint32_t, float, Invalid>();
  225. // Invalid (u)int64_t-pairings do not compile.
  226. CheckArgsInferType<uint64_t, int64_t, Invalid>();
  227. CheckArgsInferType<int64_t, float, Invalid>();
  228. CheckArgsInferType<int64_t, double, Invalid>();
  229. // Properly promotes float.
  230. CheckArgsInferType<float, double, double>();
  231. }
  232. } // namespace