Skip to content
81 changes: 40 additions & 41 deletions stan/math/prim/prob/std_normal_log_qf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,56 +71,55 @@ inline double std_normal_log_qf(double log_p) {
-15.76637472711685,
-33.82373901099482};

double val;
double log_q = log_p <= LOG_HALF ? log_diff_exp(LOG_HALF, log_p)
: log_diff_exp(log_p, LOG_HALF);
int log_q_sign = log_p <= LOG_HALF ? -1 : 1;
double log_r = log_q_sign == -1 ? log_p : log1m_exp(log_p);

if (log_q <= -0.85566611005772) {
double log_r = log_diff_exp(-1.71133222011544, 2 * log_q);
double log_agg_a = log_sum_exp(log_a[7] + log_r, log_a[6]);
double log_agg_b = log_sum_exp(log_b[7] + log_r, log_b[6]);

for (int i = 0; i < 6; i++) {
log_agg_a = log_sum_exp(log_agg_a + log_r, log_a[5 - i]);
log_agg_b = log_sum_exp(log_agg_b + log_r, log_b[5 - i]);
}
if (stan::math::is_inf(log_r)) {
return 0;
}

return log_q_sign * exp(log_q + log_agg_a - log_agg_b);
double log_inner_r;
double log_pre_mult;
const double* num_ptr;
const double* den_ptr;

static constexpr double LOG_FIVE = LOG_TEN - LOG_TWO;
static constexpr double LOG_16 = LOG_TWO * 4;
static constexpr double LOG_425 = 6.0520891689244171729;
static constexpr double LOG_425_OVER_1000 = LOG_425 - LOG_TEN * 3;

if (log_q <= LOG_425_OVER_1000) {
log_inner_r = log_diff_exp(LOG_425_OVER_1000 * 2, log_q * 2);
log_pre_mult = log_q;
num_ptr = &log_a[0];
den_ptr = &log_b[0];
} else {
double log_r = log_q_sign == -1 ? log_p : log1m_exp(log_p);

if (stan::math::is_inf(log_r)) {
return 0;
}

log_r = log(sqrt(-log_r));

if (log_r <= 1.60943791243410) {
log_r = log_diff_exp(log_r, 0.47000362924573);
double log_agg_c = log_sum_exp(log_c[7] + log_r, log_c[6]);
double log_agg_d = log_sum_exp(log_d[7] + log_r, log_d[6]);

for (int i = 0; i < 6; i++) {
log_agg_c = log_sum_exp(log_agg_c + log_r, log_c[5 - i]);
log_agg_d = log_sum_exp(log_agg_d + log_r, log_d[5 - i]);
}
val = exp(log_agg_c - log_agg_d);
double log_temp_r = log(-log_r) / 2.0;
if (log_temp_r <= LOG_FIVE) {
log_inner_r = log_diff_exp(log_temp_r, LOG_16 - LOG_TEN);
num_ptr = &log_c[0];
den_ptr = &log_d[0];
} else {
log_r = log_diff_exp(log_r, 1.60943791243410);
double log_agg_e = log_sum_exp(log_e[7] + log_r, log_e[6]);
double log_agg_f = log_sum_exp(log_f[7] + log_r, log_f[6]);

for (int i = 0; i < 6; i++) {
log_agg_e = log_sum_exp(log_agg_e + log_r, log_e[5 - i]);
log_agg_f = log_sum_exp(log_agg_f + log_r, log_f[5 - i]);
}
val = exp(log_agg_e - log_agg_f);
log_inner_r = log_diff_exp(log_temp_r, LOG_FIVE);
num_ptr = &log_e[0];
den_ptr = &log_f[0];
}
if (log_q_sign == -1)
return -val;
log_pre_mult = 0.0;
}
return val;

// As computation requires evaluating r^8, this causes a loss of precision,
// even when on the log space. We can mitigate this by scaling the
// exponentiated result (dividing by 10), since the same scaling is applied
// to the numerator and denominator.
Eigen::VectorXd log_r_pow
= Eigen::ArrayXd::LinSpaced(8, 0, 7) * log_inner_r - LOG_TEN;
Eigen::Map<const Eigen::VectorXd> num_map(num_ptr, 8);
Eigen::Map<const Eigen::VectorXd> den_map(den_ptr, 8);
double log_result
= log_sum_exp(log_r_pow + num_map) - log_sum_exp(log_r_pow + den_map);
return log_q_sign * exp(log_pre_mult + log_result);
}

/**
Expand Down