diff --git a/stan/math/prim/prob/std_normal_log_qf.hpp b/stan/math/prim/prob/std_normal_log_qf.hpp index b69ab23026c..895b4997a0a 100644 --- a/stan/math/prim/prob/std_normal_log_qf.hpp +++ b/stan/math/prim/prob/std_normal_log_qf.hpp @@ -71,56 +71,55 @@ inline double std_normal_log_qf(double log_p) { -15.76637472711685, -33.82373901099482}; - double val; double log_q = log_p <= LOG_HALF ? log_diff_exp(LOG_HALF, log_p) : log_diff_exp(log_p, LOG_HALF); int log_q_sign = log_p <= LOG_HALF ? -1 : 1; + double log_r = log_q_sign == -1 ? log_p : log1m_exp(log_p); - if (log_q <= -0.85566611005772) { - double log_r = log_diff_exp(-1.71133222011544, 2 * log_q); - double log_agg_a = log_sum_exp(log_a[7] + log_r, log_a[6]); - double log_agg_b = log_sum_exp(log_b[7] + log_r, log_b[6]); - - for (int i = 0; i < 6; i++) { - log_agg_a = log_sum_exp(log_agg_a + log_r, log_a[5 - i]); - log_agg_b = log_sum_exp(log_agg_b + log_r, log_b[5 - i]); - } + if (stan::math::is_inf(log_r)) { + return 0; + } - return log_q_sign * exp(log_q + log_agg_a - log_agg_b); + double log_inner_r; + double log_pre_mult; + const double* num_ptr; + const double* den_ptr; + + static constexpr double LOG_FIVE = LOG_TEN - LOG_TWO; + static constexpr double LOG_16 = LOG_TWO * 4; + static constexpr double LOG_425 = 6.0520891689244171729; + static constexpr double LOG_425_OVER_1000 = LOG_425 - LOG_TEN * 3; + + if (log_q <= LOG_425_OVER_1000) { + log_inner_r = log_diff_exp(LOG_425_OVER_1000 * 2, log_q * 2); + log_pre_mult = log_q; + num_ptr = &log_a[0]; + den_ptr = &log_b[0]; } else { - double log_r = log_q_sign == -1 ? log_p : log1m_exp(log_p); - - if (stan::math::is_inf(log_r)) { - return 0; - } - - log_r = log(sqrt(-log_r)); - - if (log_r <= 1.60943791243410) { - log_r = log_diff_exp(log_r, 0.47000362924573); - double log_agg_c = log_sum_exp(log_c[7] + log_r, log_c[6]); - double log_agg_d = log_sum_exp(log_d[7] + log_r, log_d[6]); - - for (int i = 0; i < 6; i++) { - log_agg_c = log_sum_exp(log_agg_c + log_r, log_c[5 - i]); - log_agg_d = log_sum_exp(log_agg_d + log_r, log_d[5 - i]); - } - val = exp(log_agg_c - log_agg_d); + double log_temp_r = log(-log_r) / 2.0; + if (log_temp_r <= LOG_FIVE) { + log_inner_r = log_diff_exp(log_temp_r, LOG_16 - LOG_TEN); + num_ptr = &log_c[0]; + den_ptr = &log_d[0]; } else { - log_r = log_diff_exp(log_r, 1.60943791243410); - double log_agg_e = log_sum_exp(log_e[7] + log_r, log_e[6]); - double log_agg_f = log_sum_exp(log_f[7] + log_r, log_f[6]); - - for (int i = 0; i < 6; i++) { - log_agg_e = log_sum_exp(log_agg_e + log_r, log_e[5 - i]); - log_agg_f = log_sum_exp(log_agg_f + log_r, log_f[5 - i]); - } - val = exp(log_agg_e - log_agg_f); + log_inner_r = log_diff_exp(log_temp_r, LOG_FIVE); + num_ptr = &log_e[0]; + den_ptr = &log_f[0]; } - if (log_q_sign == -1) - return -val; + log_pre_mult = 0.0; } - return val; + + // As computation requires evaluating r^8, this causes a loss of precision, + // even when on the log space. We can mitigate this by scaling the + // exponentiated result (dividing by 10), since the same scaling is applied + // to the numerator and denominator. + Eigen::VectorXd log_r_pow + = Eigen::ArrayXd::LinSpaced(8, 0, 7) * log_inner_r - LOG_TEN; + Eigen::Map num_map(num_ptr, 8); + Eigen::Map den_map(den_ptr, 8); + double log_result + = log_sum_exp(log_r_pow + num_map) - log_sum_exp(log_r_pow + den_map); + return log_q_sign * exp(log_pre_mult + log_result); } /**