diff --git a/include/boost/math/distributions/dirichlet.hpp b/include/boost/math/distributions/dirichlet.hpp new file mode 100644 index 0000000000..b9daf21ff5 --- /dev/null +++ b/include/boost/math/distributions/dirichlet.hpp @@ -0,0 +1,450 @@ +// boost/math/distributions/dirichlet.hpp + +// Copyright Mrityunjay Tripathi 2020. + +// Use, modification and distribution are subject to the +// Boost Software License, Version 1.0. +// (See accompanying file LICENSE_1_0.txt +// or copy at http://www.boost.org/LICENSE_1_0.txt) + +// https://en.wikipedia.org/wiki/Dirichlet_distribution +// https://mast.queensu.ca/~communications/Papers/msc-jiayu-lin.pdf + +// The Dirichlet distribution is a family of continuous multivariate probability +// distributions parameterized by a vector 'alpha' of positive reals. +// It is a multivariate generalization of the beta distribution, hence its +// alternative name of multivariate beta distribution (MBD). +// Dirichlet distributions are commonly used as prior distributions in +// Bayesian statistics, and in fact the Dirichlet distribution is the +// conjugate prior of the categorical distribution and multinomial distribution. + +#ifndef BOOST_MATH_DISTRIBUTIONS_DIRICHLET_HPP +#define BOOST_MATH_DISTRIBUTIONS_DIRICHLET_HPP + +#include +#include +#include +#include +#include +#include +#include + +#if defined(BOOST_MSVC) +#pragma warning(push) +#pragma warning(disable : 4702) // unreachable code +// in domain_error_imp in error_handling +#endif + +#include + +namespace boost +{ +namespace math +{ +namespace dirichlet_detail +{ +// Common error checking routines for dirichlet distribution function: +template +inline bool check_alpha(const char *function, + const RandomAccessContainer &alpha, + typename RandomAccessContainer::value_type *result, + const Policy &pol) +{ + using RealType = typename RandomAccessContainer::value_type; + if (alpha.size() < 1) + { + *result = policies::raise_domain_error( + function, + "Size of alpha vector is %1%, but must be > 0 !", alpha.size(), pol); + return false; + } + for (decltype(alpha.size()) i = 0; i < alpha.size(); ++i) + { + if (!(boost::math::isfinite)(alpha[i]) || (alpha[i] <= 0)) + { + *result = policies::raise_domain_error( + function, + "alpha Parameter is %1%, but must be > 0 !", alpha[i], pol); + return false; + } + } + return true; +} // bool check_alpha + +template +inline bool check_x(const char *function, + const RandomAccessContainer &x, + typename RandomAccessContainer::value_type *result, + const Policy &pol) +{ + using RealType = typename RandomAccessContainer::value_type; + if (x.size() < 1) + { + *result = policies::raise_domain_error( + function, + "Size of x is %1%, but must be > 0 !", x.size(), pol); + return false; + } + for (decltype(x.size()) i = 0; i < x.size(); ++i) + { + if (!(boost::math::isfinite)(x[i]) || (x[i] < 0) || (x[i] > 1)) + { + *result = policies::raise_domain_error( + function, + "x argument is %1%, but must be >= 0 and <= 1 !", x[i], pol); + return false; + } + } + RealType s = accumulate(x.begin(), x.end(), RealType(0)); + if (s > static_cast(1.0)) + { + *result = policies::raise_domain_error( + function, + "Sum of quantiles is %1%, but must be <= 1 !", s, pol); + return false; + } + return true; +} // bool check_x + + +template +inline bool check_alpha_and_x(const char *function, + const RandomAccessContainer &alpha, + const RandomAccessContainer &x, + typename RandomAccessContainer::value_type *result, + const Policy &pol) +{ + return check_alpha(function, alpha, result, pol) && check_x(function, x, result, pol); +} // bool check_dist_and_x + +template +inline bool check_mean(const char *function, + const RandomAccessContainer &mean, + typename RandomAccessContainer::value_type *result, + const Policy &pol) +{ + using RealType = typename RandomAccessContainer::value_type; + if (mean.size() < 1) + { + *result = policies::raise_domain_error( + function, + "Size of mean vector is %1%, but must be > 0 !", mean.size(), pol); + return false; + } + for (decltype(mean.size()) i = 0; i < mean.size(); ++i) + { + if (!(boost::math::isfinite)(mean[i]) || (mean[i] <= 0)) + { + *result = policies::raise_domain_error( + function, + "mean argument is %1%, but must be > 0 !", mean[i], pol); + return false; + } + } + return true; +} // bool check_mean + +template +inline bool check_variance(const char *function, + const RandomAccessContainer &variance, + typename RandomAccessContainer::value_type *result, + const Policy &pol) +{ + using RealType = typename RandomAccessContainer::value_type; + using std::invalid_argument; + if (variance.size() < 1) + { + *result = policies::raise_domain_error( + function, + "Size of variance vector is %1%, but must be > 0 !", variance.size(), pol); + return false; + } + for (decltype(variance.size()) i = 0; i < variance.size(); ++i) + { + if (!(boost::math::isfinite)(variance[i]) || (variance[i] <= 0)) + { + *result = policies::raise_domain_error( + function, + "variance argument is %1%, but must be > 0 !", variance[i], pol); + return false; + } + } + return true; +} // bool check_variance + +template +inline bool check_mean_and_variance(const char *function, + const RandomAccessContainer &mean, + const RandomAccessContainer &variance, + typename RandomAccessContainer::value_type *result, + const Policy &pol) +{ + return check_mean(function, mean, result, pol) && check_variance(function, variance, result, pol); +} // bool check_mean_and_variance +} // namespace dirichlet_detail + +template , class Policy = policies::policy<>> +class dirichlet_distribution +{ + using RealType = typename RandomAccessContainer::value_type; + +public: + dirichlet_distribution(RandomAccessContainer &&alpha) : m_alpha(alpha) + { + RealType result = 0; + const char *function = "boost::math::dirichlet_distribution<%1%>::dirichlet_distribution"; + dirichlet_detail::check_alpha(function, alpha, &result, Policy()); + } // dirichlet_distribution constructor. + + // Get the concentration parameters. + const RandomAccessContainer &get_alpha() const { return m_alpha; } + + // Get the order of concentration parameters. + auto order() const { return m_alpha.size(); } + + // Get alpha from mean and variance. + auto find_alpha( + RandomAccessContainer &mean, // Expected value of mean. + RandomAccessContainer &variance) // Expected value of variance. + { + assert(("Dimensions of mean and variance must be same!", mean.size() == variance.size())); + static const char *function = "boost::math::dirichlet_distribution<%1%>::find_alpha"; + RealType result = 0; // of error checks. + if (!dirichlet_detail::check_mean_and_variance(function, mean, variance, &result, Policy())) + { + return result; + } + for (decltype(mean.size()) i = 0; i < mean.size(); ++i) + { + m_alpha[i] = mean[i] * (((mean[i] * (1 - mean[i])) / variance[i]) - 1); + } + } // void find_alpha + + RealType normalizing_constant(RealType b = 0.0) const + { + // B(a1,a2,...ak) = (tgamma(a1)*tgamma(a2)...*tgamma(ak)/tgamma(a1+a2+...+ak) + RealType mb = 1.0; + RealType alpha_sum = accumulate(m_alpha.begin(), m_alpha.end(), b * m_alpha.size()); + for (decltype(m_alpha.size()) i = 0; i < m_alpha.size(); ++i) + { + mb *= tgamma(m_alpha[i] + b); + } + mb /= tgamma(alpha_sum); + return mb; + } // normalizing_constant + + RealType sum_alpha() const + { + RealType init = 0.0; + return accumulate(m_alpha.begin(), m_alpha.end(), init); + } // sum_alpha + +private: + RandomAccessContainer m_alpha; // https://en.wikipedia.org/wiki/Concentration_parameter. +}; // template class dirichlet_distribution + +template +inline const std::pair< + typename RandomAccessContainer::value_type, + typename RandomAccessContainer::value_type> +range(const dirichlet_distribution & /* dist */) +{ // Range of permissible values for random variable x. + using boost::math::tools::max_value; + using RealType = typename RandomAccessContainer::value_type; + return std::pair(static_cast(0), static_cast(1)); +} + +template +inline const std::pair< + typename RandomAccessContainer::value_type, + typename RandomAccessContainer::value_type> +support(const dirichlet_distribution & /* dist */) +{ // Range of supported values for random variable x. + // This is range where cdf rises from 0 to 1, and outside it, the pdf is zero. + using RealType = typename RandomAccessContainer::value_type; + return std::pair(static_cast(0), static_cast(1)); +} + +template +inline RandomAccessContainer mean(const dirichlet_distribution &dist) +{ // Mean of dirichlet distribution = c[i]/sum(c). + using RealType = typename RandomAccessContainer::value_type; + RealType A = dist.sum_alpha(); + RandomAccessContainer m(dist.order()); + for (decltype(dist.order()) i = 0; i < dist.order(); ++i) + { + m[i] = dist.get_alpha()[i] / A; + } + return m; +} // mean + +template +inline RandomAccessContainer variance(const dirichlet_distribution &dist) +{ + using RealType = typename RandomAccessContainer::value_type; + RealType A = dist.sum_alpha(); + RandomAccessContainer v(dist.order()); + for (decltype(dist.order()) i = 0; i < dist.order(); ++i) + { + v[i] = (dist.get_alpha()[i] / A) * (1 - dist.get_alpha()[i] / A) / (1 + A); + } + return v; +} // variance + +template +inline RandomAccessContainer standard_deviation(const dirichlet_distribution &dist) +{ + using std::sqrt; + RandomAccessContainer std = variance(dist); + for (decltype(dist.order()) i = 0; i < std.size(); ++i) + { + std[i] = sqrt(std[i]); + } + return std; +} // standard_deviation + +template +inline RandomAccessContainer mode(const dirichlet_distribution &dist) +{ + using RealType = typename RandomAccessContainer::value_type; + static const char *function = "boost::math::mode(dirichlet_distribution<%1%> const&)"; + RandomAccessContainer result(1, 0); + RealType A = dist.sum_alpha(); + RandomAccessContainer m(dist.order()); + for (decltype(dist.order()) i = 0; i < m.size(); ++i) + { + if (dist.get_alpha()[i] <= 1) + { + result[0] = policies::raise_domain_error( + function, + "mode undefined for alpha = %1%, must be > 1!", dist.get_alpha()[i], Policy()); + return result; + } + else + { + m[i] = (dist.get_alpha()[i] - 1) / (A - dist.order()); + } + } + return m; +} // mode + +// Differential Entropy of Dirichlet Distribution +template +inline typename RandomAccessContainer::value_type entropy(const dirichlet_distribution &dist) +{ + using RealType = typename RandomAccessContainer::value_type; + using std::log; + RealType ent = log(dist.normalizing_constant()) + (dist.sum_alpha() - dist.order()) * digamma(dist.sum_alpha()); + for (decltype(dist.order()) i = 0; i < dist.order(); ++i) + { + ent -= (dist.get_alpha()[i] - 1) * digamma(dist.get_alpha()[i]); + } + return ent; +} + +template +inline RandomAccessContainer skewness(const dirichlet_distribution &dist) +{ + using RealType = typename RandomAccessContainer::value_type; + using std::sqrt; + RandomAccessContainer s(dist.order()); + RealType A = dist.sum_alpha(); + RealType aj; + for (decltype(dist.order()) i = 0; i < dist.order(); ++i) + { + aj = dist.get_alpha()[i]; + s[i] = sqrt(aj * (A + 1) / (A - aj)) * ((aj + 2) * (aj + 1) * A * A / (aj * (A + 2) * (A - aj)) - 3 - aj * (A + 1) / (A - aj)); + } + return s; +} + +template +inline RandomAccessContainer kurtosis(const dirichlet_distribution &dist) +{ + using RealType = typename RandomAccessContainer::value_type; + using std::pow; + RandomAccessContainer k(dist.order()); + RealType A = dist.sum_alpha(); + RealType aj; + for (decltype(dist.order()) i = 0; i < dist.order(); ++i) + { + aj = dist.get_alpha()[i]; + k[i] = ((aj + 2) * (aj + 1) * ((aj + 3) * A / (A + 3) / aj - 4) + 6 * (aj + 1) * aj / (A + 1) / A - 3 * pow(aj / A, 2)) / std::pow((A - aj) / A / (A + 1), 2); + } + return k; +} + +template +inline RandomAccessContainer kurtosis_excess(const dirichlet_distribution &dist) +{ + RandomAccessContainer ke = kurtosis(dist); + for (decltype(dist.order()) i = 0; i < dist.order(); ++i) + { + ke[i] = ke[i] - 3; + } + return ke; +} + +template +inline typename RandomAccessContainer::value_type pdf( + const dirichlet_distribution &dist, + const RandomAccessContainer &x) +{ // Probability Density/Mass Function. + using RealType = typename RandomAccessContainer::value_type; + using std::pow; + BOOST_FPU_EXCEPTION_GUARD + BOOST_MATH_STD_USING // for ADL of std functions + + const char *function = "boost::math::pdf(dirichlet_distribution<%1%> const&, %1%)"; + RealType result = 0; + if (!dirichlet_detail::check_x(function, x, &result, Policy())) + { + return result; + } + + RealType f = 1; + for (decltype(dist.order()) i = 0; i < dist.order(); ++i) + { + f *= pow(x[i], dist.get_alpha()[i] - 1); + } + f /= dist.normalizing_constant(); + return f; +} // pdf + +template +inline typename RandomAccessContainer::value_type cdf( + const dirichlet_distribution &dist, + const RandomAccessContainer &x) +{ // Cumulative Distribution Function dirichlet. + using RealType = typename RandomAccessContainer::value_type; + using std::pow; + BOOST_MATH_STD_USING // for ADL of std functions + RealType A = dist.sum_alpha(); + const char *function = "boost::math::cdf(dirichlet_distribution<%1%> const&, %1%)"; + RealType result = 0; // Arguments check. + if (!dirichlet_detail::check_x(function, x, &result, Policy())) + { + return result; + } + RealType c = 1; + for (decltype(dist.order()) i = 0; i < dist.order(); ++i) + { + c *= pow(x[i], dist.get_alpha()[i]) / tgamma(dist.get_alpha()[i]) / dist.get_alpha()[i]; + } + c *= tgamma(A); + return c; +} // dirichlet cdf + +} // namespace math +} // namespace boost + +// This include must be at the end, *after* the accessors +// for this distribution have been defined, in order to +// keep compilers that support two-phase lookup happy. +#include + +#if defined(BOOST_MSVC) +#pragma warning(pop) +#endif + +#endif // BOOST_MATH_DISTRIBUTIONS_DIRICHLET_HPP diff --git a/test/test_dirichlet_dist.cpp b/test/test_dirichlet_dist.cpp new file mode 100644 index 0000000000..eb9f8369e7 --- /dev/null +++ b/test/test_dirichlet_dist.cpp @@ -0,0 +1,364 @@ +// test_dirichlet_dist.cpp + +// Copyright Mrityunjay Tripathi 2020. + +// Use, modification and distribution are subject to the +// Boost Software License, Version 1.0. +// (See accompanying file LICENSE_1_0.txt +// or copy at http://www.boost.org/LICENSE_1_0.txt) + +// Basic sanity tests for the Dirichlet Distribution. + +#ifdef _MSC_VER +#pragma warning(disable : 4127) // conditional expression is constant. +#pragma warning(disable : 4996) // POSIX name for this item is deprecated. +#pragma warning(disable : 4224) // nonstandard extension used : formal parameter 'arg' was previously defined as a type. +#endif + +#define BOOST_TEST_MAIN +#define BOOST_MATH_CHECK_THROW +#define BOOST_LIB_DIAGNOSTIC +#define BOOST_TEST_MODULE + +#include +#include +#include // for test_main +#include +#include // for real_concept +#include // for dirichlet_distribution +#include // for BOOST_CHECK_CLOSE_FRACTION +#include "test_out_of_range.hpp" +#include "math_unit_test.hpp" + +using boost::math::dirichlet_distribution; +using boost::math::concepts::real_concept; +using std::domain_error; +using std::numeric_limits; + +template +void test_spot( + RandomAccessContainer &&alpha, // concentration parameters 'a' + RandomAccessContainer &&x, // quantiles 'x' + RandomAccessContainer &&mean, // mean + RandomAccessContainer &&var, // variance + typename RandomAccessContainer::value_type entropy, // entropy + typename RandomAccessContainer::value_type pdf, // pdf + typename RandomAccessContainer::value_type tol) // Test tolerance. +{ + // using RealType = typename RandomAccessContainer::value_type; + typedef RandomAccessContainer V; + boost::math::dirichlet_distribution diri(std::move(alpha)); + + V calc_mean = boost::math::mean(diri); + V calc_variance = boost::math::variance(diri); + + BOOST_CHECK_CLOSE_FRACTION(boost::math::pdf(diri, x), pdf, tol); + BOOST_CHECK_CLOSE_FRACTION(boost::math::entropy(diri), entropy, tol); + + for (decltype(alpha.size()) i = 0; i < alpha.size(); ++i) + { + BOOST_CHECK_CLOSE_FRACTION(calc_mean[i], mean[i], tol); + BOOST_CHECK_CLOSE_FRACTION(calc_variance[i], var[i], tol); + } +} // template void test_spot + +template +void test_spots() +{ + typedef RandomAccessContainer V; + using RealType = typename V::value_type; + RealType tolerance = 1e-8; + + // Error checks: + // Necessary conditions for instantiation: + // 1. alpha.size() > 0. + // 2. alpha[i] > 0. + + V alpha; // alpha.size() == 0. + V x; // x.size() == 0. + BOOST_MATH_CHECK_THROW(dirichlet_distribution(std::move(alpha)), std::domain_error); + alpha.resize(2); + alpha[0] = static_cast(0.35); + alpha[1] = static_cast(-1.72); // alpha[1] < 0. + BOOST_MATH_CHECK_THROW(dirichlet_distribution(std::move(alpha)), std::domain_error); + + // Domain test for pdf. Necessary conditions for pdf: + // 1. alpha[i] > 0. + // 2. x.size() > 0. + // 3. 0 <= x[i] <=1. + // 4. sum(x) <= 1. + alpha[0] = static_cast(0.2); + alpha[1] = static_cast(1.7); + BOOST_MATH_CHECK_THROW(boost::math::pdf(dirichlet_distribution(std::move(alpha)), x), std::domain_error); + x[0] = static_cast(0.5); + x[1] = static_cast(0.5); + BOOST_MATH_CHECK_THROW(boost::math::pdf(dirichlet_distribution(std::move(alpha)), x), std::domain_error); + + alpha[0] = static_cast(1.36); + alpha[1] = static_cast(0.0); // alpha[1] = 0. + x[0] = static_cast(0.47); + x[1] = static_cast(0.53); + BOOST_MATH_CHECK_THROW(boost::math::pdf(dirichlet_distribution(std::move(alpha)), x), std::domain_error); + + alpha[0] = static_cast(1.26); + alpha[1] = static_cast(2.99); + x[0] = static_cast(0.5); + x[1] = static_cast(0.75); // sum(x) > 1.0 + BOOST_MATH_CHECK_THROW(boost::math::pdf(dirichlet_distribution(std::move(alpha)), x), std::domain_error); + + alpha[0] = static_cast(1.56); + alpha[1] = static_cast(4.00); + x[0] = static_cast(0.31); + x[1] = static_cast(-0.03); // x[1] < 0. + BOOST_MATH_CHECK_THROW(boost::math::pdf(dirichlet_distribution(std::move(alpha)), x), std::domain_error); + + alpha[0] = static_cast(1.56); + alpha[1] = static_cast(4.00); + x[0] = static_cast(0.31); + x[1] = static_cast(1.06); // x[1] > 1. + BOOST_MATH_CHECK_THROW(boost::math::pdf(dirichlet_distribution(std::move(alpha)), x), std::domain_error); + + // Domain test for cdf. Necessary conditions for cdf: + // 1. alpha[i] > 0 + // 2. 0 <= x[i] <= 1 + // 3. sum(x) <= 1. + alpha[0] = static_cast(1.56); + alpha[1] = static_cast(4.00); + x[0] = static_cast(0.31); + x[1] = static_cast(1.06); // x[1] > 1. + BOOST_MATH_CHECK_THROW(boost::math::cdf(dirichlet_distribution(std::move(alpha)), x), std::domain_error); + + alpha[0] = static_cast(3.756); + alpha[1] = static_cast(4.91); + x[0] = static_cast(0.31); + x[1] = static_cast(-1.06); // x[1] < 0. + BOOST_MATH_CHECK_THROW(boost::math::cdf(dirichlet_distribution(std::move(alpha)), x), std::domain_error); + + alpha[0] = static_cast(1.56); + alpha[1] = static_cast(-4.00); // alpha[1] < 0 + x[0] = static_cast(0.31); + x[1] = static_cast(0.69); + BOOST_MATH_CHECK_THROW(boost::math::cdf(dirichlet_distribution(std::move(alpha)), x), std::domain_error); + + alpha[0] = static_cast(0.0); + alpha[1] = static_cast(4.00); // alpha[0] = 0. + x[0] = static_cast(0.25); + x[1] = static_cast(0.75); + BOOST_MATH_CHECK_THROW(boost::math::cdf(dirichlet_distribution(std::move(alpha)), x), std::domain_error); + + alpha[0] = static_cast(1.56); + alpha[1] = static_cast(4.00); + x[0] = static_cast(0.31); + x[1] = static_cast(0.71); // sum(x) > 1. + BOOST_MATH_CHECK_THROW(boost::math::cdf(dirichlet_distribution(std::move(alpha)), x), std::domain_error); + + // Domain test for mode. Necessary conditions for mode: + // 1. alpha[i] > 1. + alpha[0] = static_cast(1.0); + alpha[1] = static_cast(1.4); // alpha[0] = 1. + BOOST_MATH_CHECK_THROW(boost::math::mode(dirichlet_distribution(std::move(alpha))), std::domain_error); + + alpha[0] = static_cast(1.56); + alpha[1] = static_cast(0.92); // alpha[1] < 1. + BOOST_MATH_CHECK_THROW(boost::math::mode(dirichlet_distribution(std::move(alpha))), std::domain_error); + + // Some exact values of pdf. + alpha[0] = static_cast(1.0), alpha[1] = static_cast(1.0); + x[0] = static_cast(0.5), x[1] = static_cast(0.5); + BOOST_CHECK_EQUAL(boost::math::pdf(dirichlet_distribution(std::move(alpha)), x), static_cast(1.0)); + + alpha[0] = static_cast(2.0), alpha[1] = static_cast(2.0); + x[0] = static_cast(0.5), x[1] = static_cast(0.5); + BOOST_CHECK_EQUAL(boost::math::pdf(dirichlet_distribution(std::move(alpha)), x), static_cast(1.5)); + + // Checking precalculated values on scipy. + // https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.dirichlet.html + alpha[0] = static_cast(5.778238829); + alpha[1] = static_cast(2.55821892973); + x[0] = static_cast(0.23667289213); + x[1] = static_cast(0.76332710787); + V mean = {static_cast(0.693128783978901), static_cast(0.3068712160210989)}; + V var = {static_cast(0.022781795654775592), static_cast(0.022781795654775592)}; + RealType entropy = static_cast(-0.516646371355904); + RealType pdf = static_cast(0.05866153821852176); + + test_spot(std::move(alpha), + std::move(x), + std::move(mean), + std::move(var), + entropy, pdf, tolerance); + + alpha[0] = static_cast(5.310948003052013); + alpha[1] = static_cast(8.003963132298916); + x[0] = static_cast(0.35042614416132284); + x[1] = static_cast(0.64957385583867716); + mean[0] = static_cast(0.398872207937724); + mean[1] = static_cast(0.601127792062276); + var[0] = static_cast(0.016749888798155716); + var[1] = static_cast(0.016749888798155716); + pdf = static_cast(2.870121181949622); + entropy = static_cast(-0.6347509574442718); + + test_spot(std::move(alpha), + std::move(x), + std::move(mean), + std::move(var), + entropy, pdf, tolerance); + + alpha[0] = static_cast(8.764102220201394); + alpha[1] = static_cast(4.348446856921846); + x[0] = static_cast(0.6037585982123262); + x[1] = static_cast(0.39624140178767375); + mean[0] = static_cast(0.6683751701255137); + mean[1] = static_cast(0.33162482987448627); + var[0] = static_cast(0.015705865813037533); + var[1] = static_cast(0.015705865813037533); + pdf = static_cast(2.473329499915834); + entropy = static_cast(-0.6769547381491741); + + test_spot(std::move(alpha), + std::move(x), + std::move(mean), + std::move(var), + entropy, pdf, tolerance); + + alpha.resize(3); + x.resize(3); + mean.resize(3); + var.resize(3); + alpha[0] = static_cast(5.622313698848736); + alpha[1] = static_cast(0.3516907178071482); + alpha[2] = static_cast(9.15629985496498); + x[0] = static_cast(0.6571425803855344); + x[1] = static_cast(0.2972004956337586); + x[2] = static_cast(0.04565692398070697); + mean[0] = static_cast(0.37159290374577736); + mean[1] = static_cast(0.023244127249099442); + mean[2] = static_cast(0.6051629690051231); + var[0] = static_cast(0.014476578600094457); + var[1] = static_cast(0.0014075269390591417); + var[2] = static_cast(0.014813158259538361); + pdf = static_cast(4.97846312846897e-08); + entropy = static_cast(-4.047215462643532); + + test_spot(std::move(alpha), + std::move(x), + std::move(mean), + std::move(var), + entropy, pdf, tolerance); + + alpha.resize(4); + x.resize(4); + mean.resize(4); + var.resize(4); + alpha[0] = static_cast(5.958168192947443); + alpha[1] = static_cast(6.823198187239482); + alpha[2] = static_cast(6.297779996686504); + alpha[3] = static_cast(4.396226676824867); + x[0] = static_cast(0.15589020332495018); + x[1] = static_cast(0.3893497609653562); + x[2] = static_cast(0.060839680922786556); + x[3] = static_cast(0.393920354786907); + mean[0] = static_cast(0.2538050483508204); + mean[1] = static_cast(0.2906534508155371); + mean[2] = static_cast(0.26827177494818794); + mean[3] = static_cast(0.1872697258854546); + var[0] = static_cast(0.007737902313764369); + var[1] = static_cast(0.00842373359916587); + var[2] = static_cast(0.008020389690635378); + var[3] = static_cast(0.006218486448329886); + pdf = static_cast(0.2649374226055107); + entropy = static_cast(-3.4416182654031537); + + test_spot(std::move(alpha), + std::move(x), + std::move(mean), + std::move(var), + entropy, pdf, tolerance); + + alpha[0] = static_cast(3.1779256968768976); + alpha[1] = static_cast(1.355989101047721); + alpha[2] = static_cast(5.594207813755373); + alpha[3] = static_cast(5.9897453525066355); + x[0] = static_cast(0.3388848203529338); + x[1] = static_cast(0.36731530174264704); + x[2] = static_cast(0.11166014002460622); + x[3] = static_cast(0.1821397378798129); + mean[0] = static_cast(0.19716787008915473); + mean[1] = static_cast(0.08412955758545015); + mean[2] = static_cast(0.34708112922785567); + mean[3] = static_cast(0.37162144309753953); + var[0] = static_cast(0.009247220589902615); + var[1] = static_cast(0.004501248361485874); + var[2] = static_cast(0.013238553973887957); + var[3] = static_cast(0.013641824239806118); + pdf = static_cast(0.06803159432725718); + entropy = static_cast(-3.398201562087422); + + test_spot(std::move(alpha), + std::move(x), + std::move(mean), + std::move(var), + entropy, pdf, tolerance); + + // No longer allow any parameter to be NaN or inf. + if (std::numeric_limits::has_quiet_NaN) + { + RealType not_a_num = std::numeric_limits::quiet_NaN(); + alpha[0] = not_a_num; + alpha[1] = static_cast(0.37); +#ifndef BOOST_NO_EXCEPTIONS + BOOST_MATH_CHECK_THROW(dirichlet_distribution(std::move(alpha)), std::domain_error); +#else + BOOST_MATH_CHECK_THROW(dirichlet_distribution(std::move(alpha)), std::domain_error); +#endif + + // Non-finite parameters should throw. + alpha[0] = static_cast(1.67); + alpha[1] = static_cast(3.8); + x[0] = not_a_num; + x[1] = static_cast(0.5); + dirichlet_distribution w(std::move(alpha)); + BOOST_MATH_CHECK_THROW(boost::math::pdf(w, x), std::domain_error); // x = NaN + BOOST_MATH_CHECK_THROW(boost::math::cdf(w, x), std::domain_error); // x = NaN + } // has_quiet_NaN + + if (std::numeric_limits::has_infinity) + { + // Attempt to construct from non-finite should throw. + RealType infinite = std::numeric_limits::infinity(); + alpha[0] = infinite; + alpha[1] = static_cast(7.2); +#ifndef BOOST_NO_EXCEPTIONS + BOOST_MATH_CHECK_THROW(dirichlet_distribution(std::move(alpha)), std::domain_error); +#else + BOOST_MATH_CHECK_THROW(dirichlet_distribution(std::move(alpha)), std::domain_error); +#endif + alpha[0] = static_cast(1.42); + alpha[1] = static_cast(7.91); + x[0] = static_cast(0.25); + x[1] = infinite; + dirichlet_distribution w(std::move(alpha)); + BOOST_MATH_CHECK_THROW(boost::math::pdf(w, x), std::domain_error); // x = inf + BOOST_MATH_CHECK_THROW(boost::math::cdf(w, x), std::domain_error); // x = inf + x[1] = -infinite; + BOOST_MATH_CHECK_THROW(boost::math::pdf(w, x), std::domain_error); // x = -inf + BOOST_MATH_CHECK_THROW(boost::math::cdf(w, x), std::domain_error); // x = -inf + } +} // test_spots() + +BOOST_AUTO_TEST_CASE(test_main) +{ + BOOST_MATH_CONTROL_FP; + test_spots>(); + + test_spots>(); + + test_spots>(); + + // #ifndef BOOST_MATH_NO_LONG_DOUBLE_MATH_FUNCTIONS + // test_spots(); // Test long double. + // #if !BOOST_WORKAROUND(__BORLANDC__, BOOST_TESTED_AT(0x582)) + // test_spots(boost::math::concepts::real_concept(0.)); // Test real concept. + // #endif +} // BOOST_AUTO_TEST_CASE( test_main ) \ No newline at end of file