diff --git a/pytensor_distributions/polyagamma.py b/pytensor_distributions/polyagamma.py
new file mode 100644
index 0000000..9a4828f
--- /dev/null
+++ b/pytensor_distributions/polyagamma.py
@@ -0,0 +1,206 @@
+import pytensor.tensor as pt
+from pytensor import scan
+from pytensor.scan.utils import until
+
+from pytensor_distributions.helper import (
+    cdf_bounds,
+    continuous_entropy,
+    continuous_kurtosis,
+    continuous_mode,
+    continuous_skewness,
+    ppf_bounds_cont,
+)
+
+
+def _lower_bound():
+    return 1e-10
+
+
+def _upper_bound(h, z):
+    m = mean(h, z)
+    s = std(h, z)
+    return m + 10 * s
+
+
+def _log_cosh_half(z):
+    """Compute log(cosh(z/2)) in a numerically stable way."""
+    abs_half_z = pt.abs(z / 2)
+    return abs_half_z + pt.log1p(pt.exp(-2 * abs_half_z)) - pt.log(2.0)
+
+
+def _log_pg_density_base(x, h, N=20):
+    """Log density of PG(h, 0) using truncated alternating series.
+
+    Uses the Jacobi series representation:
+    f(x; h, 0) = (2^{h-1} / Gamma(h)) * sum_{n=0}^{N-1} (-1)^n
+        * [Gamma(n+h) / n!] * (2n+h) / sqrt(2*pi*x^3)
+        * exp(-(2n+h)^2 / (8x))
+
+    Direct signed summation in shifted linear space avoids the pairing
+    approach which fails when consecutive terms are not monotonically decreasing.
+    """
+    n = pt.arange(N, dtype="float64")
+    n_bc = n.reshape((-1,) + (1,) * x.ndim)
+
+    c = 2 * n_bc + h
+    signs = pt.switch(pt.eq(n_bc % 2, 0), 1.0, -1.0)
+
+    # Log of absolute value of each term (without the global prefactor)
+    log_abs_term = (
+        pt.gammaln(n_bc + h)
+        - pt.gammaln(n_bc + 1)
+        + pt.log(c)
+        - 0.5 * pt.log(2 * pt.pi)
+        - 1.5 * pt.log(x)
+        - c**2 / (8 * x)
+    )
+
+    # Shift to prevent overflow, then sum with signs in linear space
+    max_log = pt.max(log_abs_term, axis=0)
+    shifted = pt.exp(log_abs_term - max_log)
+    signed_sum = pt.sum(signs * shifted, axis=0)
+
+    # Clamp to small positive value for numerical safety in far tail
+    log_series = pt.log(pt.maximum(signed_sum, 1e-300)) + max_log
+
+    # Global prefactor: (h-1)*log(2) - gammaln(h)
+    log_prefactor = (h - 1) * pt.log(2.0) - pt.gammaln(h)
+
+    return log_prefactor + log_series
+
+
+def mean(h, z):
+    z = pt.as_tensor_variable(z)
+    small = pt.lt(pt.abs(z), 1e-6)
+    safe_z = pt.switch(small, pt.ones_like(z), z)
+    result_general = h / (2 * safe_z) * pt.tanh(safe_z / 2)
+    result_zero = h / 4.0
+    return pt.switch(small, result_zero, result_general)
+
+
+def mode(h, z):
+    return continuous_mode(_lower_bound(), _upper_bound(h, z), logpdf, h, z)
+
+
+def median(h, z):
+    return ppf(0.5, h, z)
+
+
+def var(h, z):
+    z = pt.as_tensor_variable(z)
+    small = pt.lt(pt.abs(z), 1e-6)
+    safe_z = pt.switch(small, pt.ones_like(z), z)
+    result_general = h / (4 * safe_z**3) * (pt.sinh(safe_z) - safe_z) / pt.cosh(safe_z / 2) ** 2
+    result_zero = h / 24.0
+    return pt.switch(small, result_zero, result_general)
+
+
+def std(h, z):
+    return pt.sqrt(var(h, z))
+
+
+def skewness(h, z):
+    return continuous_skewness(_lower_bound(), _upper_bound(h, z), logpdf, h, z)
+
+
+def kurtosis(h, z):
+    return continuous_kurtosis(_lower_bound(), _upper_bound(h, z), logpdf, h, z)
+
+
+def entropy(h, z):
+    return continuous_entropy(_lower_bound(), _upper_bound(h, z), logpdf, h, z)
+
+
+def logpdf(x, h, z):
+    x = pt.as_tensor_variable(x)
+    log_tilt = h * _log_cosh_half(z) - z**2 * x / 2
+    result = log_tilt + _log_pg_density_base(x, h)
+    return pt.switch(pt.le(x, 0), -pt.inf, result)
+
+
+def pdf(x, h, z):
+    return pt.exp(logpdf(x, h, z))
+
+
+def cdf(x, h, z):
+    x = pt.as_tensor_variable(x)
+    n_points = 500
+    lower = _lower_bound()
+
+    t = pt.linspace(lower, x, n_points)
+    pdf_vals = pdf(t, h, z)
+    dx = (x - lower) / (n_points - 1)
+    result = dx * (0.5 * pdf_vals[0] + pt.sum(pdf_vals[1:-1], axis=0) + 0.5 * pdf_vals[-1])
+
+    return cdf_bounds(result, x, 0, pt.inf)
+
+
+def logcdf(x, h, z):
+    return pt.switch(pt.le(x, 0), -pt.inf, pt.log(cdf(x, h, z)))
+
+
+def sf(x, h, z):
+    return 1.0 - cdf(x, h, z)
+
+
+def logsf(x, h, z):
+    return pt.log1p(-cdf(x, h, z))
+
+
+def ppf(q, h, z, max_iter=50, tol=1e-8):
+    # Use log-normal approximation as initial guess to avoid Newton oscillation.
+    # PG(h, z) is positive and right-skewed; a log-normal is a good proxy.
+    m = mean(h, z)
+    v = var(h, z)
+    sigma_ln_sq = pt.log1p(v / m**2)
+    mu_ln = pt.log(m) - sigma_ln_sq / 2
+    sigma_ln = pt.sqrt(sigma_ln_sq)
+    x0 = pt.exp(mu_ln + sigma_ln * pt.sqrt(2.0) * pt.erfinv(2 * q - 1))
+    x0 = pt.maximum(x0, _lower_bound())
+
+    lb = _lower_bound()
+
+    def step(x_prev):
+        x_prev_squeezed = pt.squeeze(x_prev)
+
+        cdf_val = cdf(x_prev_squeezed, h, z)
+        f_x = pt.maximum(pdf(x_prev_squeezed, h, z), 1e-10)
+        delta = (cdf_val - q) / f_x
+
+        max_step = pt.maximum(pt.abs(x_prev_squeezed), 0.5)
+        delta = pt.clip(delta, -max_step, max_step)
+        x_new = pt.maximum(x_prev_squeezed - delta, lb)
+
+        converged = pt.abs(x_new - x_prev_squeezed) < tol
+        x_new = pt.switch(converged, x_prev_squeezed, x_new)
+
+        all_converged = pt.all(converged)
+        return pt.shape_padleft(x_new), until(all_converged)
+
+    x_seq = scan(fn=step, outputs_info=pt.shape_padleft(x0), n_steps=max_iter, return_updates=False)
+
+    return ppf_bounds_cont(x_seq[-1].squeeze(), q, 0, pt.inf)
+
+
+def isf(q, h, z):
+    return ppf(1.0 - q, h, z)
+
+
+def rvs(h, z, size=None, random_state=None):
+    K = 200
+    k = pt.arange(1, K + 1, dtype="float64")
+
+    if size is None:
+        gamma_size = (K,)
+    elif isinstance(size, int):
+        gamma_size = (K, size)
+    else:
+        gamma_size = (K, *size)
+
+    gamma_draws = pt.random.gamma(h, scale=1.0, size=gamma_size, rng=random_state)
+
+    z2_term = z**2 / (4 * pt.pi**2)
+    k_bc = k.reshape((-1,) + (1,) * (gamma_draws.ndim - 1))
+    denom = (k_bc - 0.5) ** 2 + z2_term
+
+    return pt.sum(gamma_draws / denom, axis=0) / (2 * pt.pi**2)
diff --git a/tests/test_polyagamma.py b/tests/test_polyagamma.py
new file mode 100644
index 0000000..1401333
--- /dev/null
+++ b/tests/test_polyagamma.py
@@ -0,0 +1,185 @@
+"""Test PolyaGamma distribution against empirical samples."""
+
+import numpy as np
+import pytensor.tensor as pt
+import pytest
+from numpy.testing import assert_allclose
+from scipy.integrate import quad
+from scipy.stats import kurtosis, skew
+
+from pytensor_distributions import polyagamma as PolyaGamma
+from tests.helper_scipy import make_params
+
+PARAMS_LIST = [
+    (1.0, 0.0),
+    (2.0, 2.0),
+    (0.5, 1.0),
+]
+
+
+@pytest.fixture(scope="module")
+def samples():
+    """Generate samples once for all tests (shared across parametrizations)."""
+    results = {}
+    for h, z in PARAMS_LIST:
+        p_params = make_params(h, z, dtype="float64")
+        rng_p = pt.random.default_rng(1)
+        rvs = PolyaGamma.rvs(*p_params, size=10_000, random_state=rng_p).eval()
+        results[(h, z)] = (p_params, rvs)
+    return results
+
+
+@pytest.mark.parametrize("params", PARAMS_LIST)
+def test_polyagamma_moments(params, samples):
+    """Theoretical moments should match empirical moments from samples."""
+    p_params, rvs = samples[params]
+
+    assert_allclose(PolyaGamma.mean(*p_params).eval(), rvs.mean(), rtol=3e-2, atol=3e-2)
+    assert_allclose(PolyaGamma.var(*p_params).eval(), rvs.var(), rtol=1e-1, atol=1e-3)
+    assert_allclose(PolyaGamma.std(*p_params).eval(), rvs.std(), rtol=1e-1, atol=1e-3)
+    assert_allclose(PolyaGamma.skewness(*p_params).eval(), skew(rvs), rtol=3e-1, atol=1e-2)
+    assert_allclose(PolyaGamma.kurtosis(*p_params).eval(), kurtosis(rvs), rtol=3e-1, atol=1e-2)
+
+
+@pytest.mark.parametrize("params", PARAMS_LIST)
+def test_polyagamma_cdf(params, samples):
+    """CDF should match empirical CDF and be monotonic on a small grid."""
+    p_params, rvs = samples[params]
+
+    sample_x = rvs[:20]
+    theoretical_cdf = PolyaGamma.cdf(sample_x, *p_params).eval()
+    for i, x in enumerate(sample_x):
+        empirical_cdf = np.mean(rvs <= x)
+        assert_allclose(theoretical_cdf[i], empirical_cdf, rtol=1e-1, atol=1e-3)
+
+    x_grid = np.linspace(np.percentile(rvs, 1), np.percentile(rvs, 99), 50)
+    cdf_vals = PolyaGamma.cdf(x_grid, *p_params).eval()
+    assert np.all(np.diff(cdf_vals) >= -1e-4), "CDF is not monotonic"
+
+
+@pytest.mark.parametrize("params", PARAMS_LIST)
+def test_polyagamma_cdf_bounds(params, samples):
+    """CDF should be 0 at lower bound and handle out-of-support."""
+    p_params, _ = samples[params]
+
+    extended_vals = np.array([0.0, np.inf, -1.0, -2.0, np.inf, np.inf])
+    expected = np.array([0.0, 1.0, 0.0, 0.0, 1.0, 1.0])
+    assert_allclose(PolyaGamma.cdf(extended_vals, *p_params).eval(), expected)
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize("params", [(1.0, 0.0)])
+def test_polyagamma_ppf(params, samples):
+    """PPF should match empirical quantiles. Slow due to scan-based solver."""
+    p_params, rvs = samples[params]
+
+    q = np.linspace(0.05, 0.95, 10)
+    theoretical = PolyaGamma.ppf(q, *p_params).eval()
+    empirical = np.quantile(rvs, q)
+    assert_allclose(theoretical, empirical, rtol=1e-1, atol=5e-3)
+
+
+@pytest.mark.slow
+def test_polyagamma_ppf_cdf_inverse(samples):
+    """CDF(PPF(q)) should recover q. Slow due to scan-based solver."""
+    p_params, _ = samples[(1.0, 0.0)]
+
+    q = np.array([0.1, 0.5, 0.9])
+    x_vals = PolyaGamma.ppf(q, *p_params).eval()
+    recovered = PolyaGamma.cdf(x_vals, *p_params).eval()
+    assert_allclose(recovered, q, atol=1e-4)
+
+
+@pytest.mark.parametrize("params", PARAMS_LIST)
+def test_polyagamma_pdf(params, samples):
+    """PDF should be non-negative and integrate to ~1."""
+    p_params, rvs = samples[params]
+
+    x_grid = np.linspace(0.01, np.percentile(rvs, 99), 100)
+    pdf_vals = PolyaGamma.pdf(x_grid, *p_params).eval()
+    assert np.all(pdf_vals >= 0), "PDF has negative values"
+
+    u_b = float(np.percentile(rvs, 99.9))
+    result, _ = quad(lambda x: PolyaGamma.pdf(x, *p_params).eval(), 0, u_b)
+    assert np.abs(result - 1) < 0.02, f"PDF integral = {result}, should be 1"
+
+
+@pytest.mark.parametrize("params", PARAMS_LIST)
+def test_polyagamma_pdf_cdf_consistency(params, samples):
+    """Numerical derivative of CDF should approximate PDF."""
+    p_params, rvs = samples[params]
+
+    x_mid = np.linspace(np.percentile(rvs, 5), np.percentile(rvs, 95), 20)
+    eps = 1e-5
+    cdf_plus = PolyaGamma.cdf(x_mid + eps, *p_params).eval()
+    cdf_minus = PolyaGamma.cdf(x_mid - eps, *p_params).eval()
+    numerical_pdf = (cdf_plus - cdf_minus) / (2 * eps)
+    pdf_vals = PolyaGamma.pdf(x_mid, *p_params).eval()
+
+    mask = np.abs(pdf_vals) > 1e-4
+    if np.any(mask):
+        rel_error = np.abs(numerical_pdf[mask] - pdf_vals[mask]) / (np.abs(pdf_vals[mask]) + 1e-10)
+        assert np.all(rel_error < 1e-2), (
+            f"PDF doesn't match CDF derivative. Max rel error: {np.max(rel_error)}"
+        )
+
+
+@pytest.mark.parametrize("params", PARAMS_LIST)
+def test_polyagamma_sf_complement(params, samples):
+    """SF + CDF should equal 1."""
+    p_params, rvs = samples[params]
+
+    x = rvs[:20]
+    cdf_vals = PolyaGamma.cdf(x, *p_params).eval()
+    sf_vals = PolyaGamma.sf(x, *p_params).eval()
+    assert_allclose(cdf_vals + sf_vals, 1.0, atol=1e-4)
+
+
+@pytest.mark.parametrize("params", PARAMS_LIST)
+def test_polyagamma_entropy(params, samples):
+    """Monte Carlo entropy should match computed entropy."""
+    p_params, rvs = samples[params]
+
+    logpdf_vals = PolyaGamma.logpdf(rvs, *p_params).eval()
+    logpdf_vals = logpdf_vals[np.isfinite(logpdf_vals)]
+    mc_entropy = -np.mean(logpdf_vals)
+    computed_entropy = PolyaGamma.entropy(*p_params).eval()
+
+    rel_error = np.abs(mc_entropy - computed_entropy) / (np.abs(computed_entropy) + 1e-10)
+    assert rel_error < 0.1, f"Entropy mismatch. MC: {mc_entropy}, Computed: {computed_entropy}"
+
+
+def test_polyagamma_mean_z_zero():
+    """Mean at z=0 should equal h/4."""
+    p_params = make_params(2.0, 0.0, dtype="float64")
+    result = PolyaGamma.mean(*p_params).eval()
+    assert_allclose(result, 0.5, rtol=1e-10)
+
+
+def test_polyagamma_var_z_zero():
+    """Variance at z=0 should equal h/24."""
+    p_params = make_params(2.0, 0.0, dtype="float64")
+    result = PolyaGamma.var(*p_params).eval()
+    assert_allclose(result, 2.0 / 24, rtol=1e-10)
+
+
+def test_polyagamma_pdf_positive():
+    """PDF should be positive on the support."""
+    p_params = make_params(1.0, 1.0, dtype="float64")
+    x = np.linspace(0.01, 2.0, 100)
+    pdf_vals = PolyaGamma.pdf(x, *p_params).eval()
+    assert np.all(pdf_vals > 0)
+
+
+def test_polyagamma_pdf_zero_outside():
+    """PDF should be zero for x <= 0."""
+    p_params = make_params(1.0, 1.0, dtype="float64")
+    assert_allclose(PolyaGamma.pdf(-1.0, *p_params).eval(), 0.0)
+    assert_allclose(PolyaGamma.pdf(0.0, *p_params).eval(), 0.0)
+
+
+def test_polyagamma_logpdf_neginf_outside():
+    """Logpdf should be -inf for x <= 0."""
+    p_params = make_params(1.0, 1.0, dtype="float64")
+    assert PolyaGamma.logpdf(-1.0, *p_params).eval() == -np.inf
+    assert PolyaGamma.logpdf(0.0, *p_params).eval() == -np.inf