gaussian_distribution_gentables.cc 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. // Copyright 2017 The Abseil Authors.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // https://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // Generates gaussian_distribution.cc
  15. //
  16. // $ blaze run :gaussian_distribution_gentables > gaussian_distribution.cc
  17. //
  18. #include "absl/random/gaussian_distribution.h"
  19. #include <cmath>
  20. #include <cstddef>
  21. #include <iostream>
  22. #include <limits>
  23. #include <string>
  24. #include "absl/base/macros.h"
  25. namespace absl {
  26. ABSL_NAMESPACE_BEGIN
  27. namespace random_internal {
  28. namespace {
  29. template <typename T, size_t N>
  30. void FormatArrayContents(std::ostream* os, T (&data)[N]) {
  31. if (!std::numeric_limits<T>::is_exact) {
  32. // Note: T is either an integer or a float.
  33. // float requires higher precision to ensure that values are
  34. // reproduced exactly.
  35. // Trivia: C99 has hexadecimal floating point literals, but C++11 does not.
  36. // Using them would remove all concern of precision loss.
  37. os->precision(std::numeric_limits<T>::max_digits10 + 2);
  38. }
  39. *os << " {";
  40. std::string separator = "";
  41. for (size_t i = 0; i < N; ++i) {
  42. *os << separator << data[i];
  43. if ((i + 1) % 3 != 0) {
  44. separator = ", ";
  45. } else {
  46. separator = ",\n ";
  47. }
  48. }
  49. *os << "}";
  50. }
  51. } // namespace
  52. class TableGenerator : public gaussian_distribution_base {
  53. public:
  54. TableGenerator();
  55. void Print(std::ostream* os);
  56. using gaussian_distribution_base::kMask;
  57. using gaussian_distribution_base::kR;
  58. using gaussian_distribution_base::kV;
  59. private:
  60. Tables tables_;
  61. };
  62. // Ziggurat gaussian initialization. For an explanation of the algorithm, see
  63. // the Marsaglia paper, "The Ziggurat Method for Generating Random Variables".
  64. // http://www.jstatsoft.org/v05/i08/
  65. //
  66. // Further details are available in the Doornik paper
  67. // https://www.doornik.com/research/ziggurat.pdf
  68. //
  69. TableGenerator::TableGenerator() {
  70. // The constants here should match the values in gaussian_distribution.h
  71. static constexpr int kC = kMask + 1;
  72. static_assert((ABSL_ARRAYSIZE(tables_.x) == kC + 1),
  73. "xArray must be length kMask + 2");
  74. static_assert((ABSL_ARRAYSIZE(tables_.x) == ABSL_ARRAYSIZE(tables_.f)),
  75. "fx and x arrays must be identical length");
  76. auto f = [](double x) { return std::exp(-0.5 * x * x); };
  77. auto f_inv = [](double x) { return std::sqrt(-2.0 * std::log(x)); };
  78. tables_.x[0] = kV / f(kR);
  79. tables_.f[0] = f(tables_.x[0]);
  80. tables_.x[1] = kR;
  81. tables_.f[1] = f(tables_.x[1]);
  82. tables_.x[kC] = 0.0;
  83. tables_.f[kC] = f(tables_.x[kC]); // 1.0
  84. for (int i = 2; i < kC; i++) {
  85. double v = (kV / tables_.x[i - 1]) + tables_.f[i - 1];
  86. tables_.x[i] = f_inv(v);
  87. tables_.f[i] = v;
  88. }
  89. }
  90. void TableGenerator::Print(std::ostream* os) {
  91. *os << "// BEGIN GENERATED CODE; DO NOT EDIT\n"
  92. "// clang-format off\n"
  93. "\n"
  94. "#include \"absl/random/gaussian_distribution.h\"\n"
  95. "\n"
  96. "namespace absl {\n"
  97. "ABSL_NAMESPACE_BEGIN\n"
  98. "namespace random_internal {\n"
  99. "\n"
  100. "const gaussian_distribution_base::Tables\n"
  101. " gaussian_distribution_base::zg_ = {\n";
  102. FormatArrayContents(os, tables_.x);
  103. *os << ",\n";
  104. FormatArrayContents(os, tables_.f);
  105. *os << "};\n"
  106. "\n"
  107. "} // namespace random_internal\n"
  108. "ABSL_NAMESPACE_END\n"
  109. "} // namespace absl\n"
  110. "\n"
  111. "// clang-format on\n"
  112. "// END GENERATED CODE";
  113. *os << std::endl;
  114. }
  115. } // namespace random_internal
  116. ABSL_NAMESPACE_END
  117. } // namespace absl
  118. int main(int, char**) {
  119. std::cerr << "\nCopy the output to gaussian_distribution.cc" << std::endl;
  120. absl::random_internal::TableGenerator generator;
  121. generator.Print(&std::cout);
  122. return 0;
  123. }