Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
206 changes: 206 additions & 0 deletions pytensor_distributions/polyagamma.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
import pytensor.tensor as pt
from pytensor import scan
from pytensor.scan.utils import until

from pytensor_distributions.helper import (
cdf_bounds,
continuous_entropy,
continuous_kurtosis,
continuous_mode,
continuous_skewness,
ppf_bounds_cont,
)


def _lower_bound():
    # Smallest x treated as inside the support; keeps log(x) and 1/x finite
    # in the density near the origin.
    return 1.0e-10


def _upper_bound(h, z):
    """Practical upper limit for numerical integration: mean + 10 std."""
    return mean(h, z) + 10 * std(h, z)


def _log_cosh_half(z):
    """Numerically stable log(cosh(z/2)).

    Uses log(cosh(a)) = |a| + log1p(exp(-2|a|)) - log(2), which never
    overflows for large |z| the way a direct cosh would.
    """
    a = pt.abs(z) / 2
    return a + pt.log1p(pt.exp(-2 * a)) - pt.log(2.0)


def _log_pg_density_base(x, h, N=20):
    """Log density of PG(h, 0) using truncated alternating series.

    Uses the Jacobi series representation:
        f(x; h, 0) = (2^{h-1} / Gamma(h)) * sum_{n=0}^{N-1} (-1)^n
            * [Gamma(n+h) / n!] * (2n+h) / sqrt(2*pi*x^3)
            * exp(-(2n+h)^2 / (8x))

    Direct signed summation in shifted linear space avoids the pairing
    approach which fails when consecutive terms are not monotonically decreasing.

    Parameters
    ----------
    x : TensorVariable
        Evaluation points; assumed strictly positive (callers mask x <= 0).
    h : TensorVariable
        Shape parameter of the distribution.
    N : int
        Number of series terms retained; default 20.
    """
    n = pt.arange(N, dtype="float64")
    # Put the series index on a new leading axis so terms broadcast
    # against x of any rank; all reductions below are over axis 0.
    n_bc = n.reshape((-1,) + (1,) * x.ndim)

    c = 2 * n_bc + h
    # (-1)^n sign of each alternating term.
    signs = pt.switch(pt.eq(n_bc % 2, 0), 1.0, -1.0)

    # Log of absolute value of each term (without the global prefactor)
    log_abs_term = (
        pt.gammaln(n_bc + h)
        - pt.gammaln(n_bc + 1)
        + pt.log(c)
        - 0.5 * pt.log(2 * pt.pi)
        - 1.5 * pt.log(x)
        - c**2 / (8 * x)
    )

    # Shift to prevent overflow, then sum with signs in linear space
    # (a signed log-sum-exp: max_log is added back after the log below).
    max_log = pt.max(log_abs_term, axis=0)
    shifted = pt.exp(log_abs_term - max_log)
    signed_sum = pt.sum(signs * shifted, axis=0)

    # Clamp to small positive value for numerical safety in far tail
    # where cancellation can push the truncated sum to zero or below.
    log_series = pt.log(pt.maximum(signed_sum, 1e-300)) + max_log

    # Global prefactor: (h-1)*log(2) - gammaln(h)
    log_prefactor = (h - 1) * pt.log(2.0) - pt.gammaln(h)

    return log_prefactor + log_series


def mean(h, z):
    """E[PG(h, z)] = h / (2z) * tanh(z/2), with the z -> 0 limit h / 4."""
    z = pt.as_tensor_variable(z)
    near_zero = pt.lt(pt.abs(z), 1e-6)
    # Replace tiny z with 1 so the general branch never divides by ~0;
    # its value is discarded by the switch below.
    z_safe = pt.switch(near_zero, pt.ones_like(z), z)
    general = h / (2 * z_safe) * pt.tanh(z_safe / 2)
    return pt.switch(near_zero, h / 4.0, general)


def mode(h, z):
    """Numerical mode via the generic continuous-mode search over the support."""
    lo, hi = _lower_bound(), _upper_bound(h, z)
    return continuous_mode(lo, hi, logpdf, h, z)


def median(h, z):
    """Median as the 0.5 quantile of the distribution."""
    return ppf(0.5, h, z)


def var(h, z):
    """Var[PG(h, z)] = h (sinh z - z) / (4 z^3 cosh^2(z/2)); limit h / 24 at z = 0."""
    z = pt.as_tensor_variable(z)
    near_zero = pt.lt(pt.abs(z), 1e-6)
    # Substitute 1 for tiny z so the general branch stays finite; the
    # switch below selects the analytic limit in that region instead.
    z_safe = pt.switch(near_zero, pt.ones_like(z), z)
    general = h / (4 * z_safe**3) * (pt.sinh(z_safe) - z_safe) / pt.cosh(z_safe / 2) ** 2
    return pt.switch(near_zero, h / 24.0, general)


def std(h, z):
    """Standard deviation: square root of the closed-form variance."""
    return pt.sqrt(var(h, z))


def skewness(h, z):
    """Numerical skewness computed from logpdf over the truncated support."""
    lo, hi = _lower_bound(), _upper_bound(h, z)
    return continuous_skewness(lo, hi, logpdf, h, z)


def kurtosis(h, z):
    """Numerical kurtosis computed from logpdf over the truncated support."""
    lo, hi = _lower_bound(), _upper_bound(h, z)
    return continuous_kurtosis(lo, hi, logpdf, h, z)


def entropy(h, z):
    """Differential entropy computed numerically from logpdf on the support."""
    lo, hi = _lower_bound(), _upper_bound(h, z)
    return continuous_entropy(lo, hi, logpdf, h, z)


def logpdf(x, h, z):
    """Log density of PG(h, z): exponential tilt of the PG(h, 0) base density.

    f(x; h, z) = cosh^h(z/2) * exp(-z^2 x / 2) * f(x; h, 0); returns -inf
    for x outside the positive support.
    """
    x = pt.as_tensor_variable(x)
    tilt = h * _log_cosh_half(z) - z**2 * x / 2
    inside = tilt + _log_pg_density_base(x, h)
    return pt.switch(pt.le(x, 0), -pt.inf, inside)


def pdf(x, h, z):
    """Density of PG(h, z), obtained by exponentiating the log density."""
    return pt.exp(logpdf(x, h, z))


def cdf(x, h, z):
    """CDF via composite trapezoidal integration of the density on [lb, x]."""
    x = pt.as_tensor_variable(x)
    npts = 500
    lb = _lower_bound()

    grid = pt.linspace(lb, x, npts)
    dens = pdf(grid, h, z)
    step = (x - lb) / (npts - 1)
    # Trapezoid rule: half-weight the endpoints, full weight interior nodes.
    interior = pt.sum(dens[1:-1], axis=0)
    area = step * (0.5 * dens[0] + interior + 0.5 * dens[-1])

    return cdf_bounds(area, x, 0, pt.inf)


def logcdf(x, h, z):
    """Log CDF; -inf for x outside the positive support."""
    out_of_support = pt.le(x, 0)
    return pt.switch(out_of_support, -pt.inf, pt.log(cdf(x, h, z)))


def sf(x, h, z):
    """Survival function, the complement of the CDF."""
    return 1.0 - cdf(x, h, z)


def logsf(x, h, z):
    """Log survival function, computed stably as log1p(-CDF)."""
    return pt.log1p(-cdf(x, h, z))


def ppf(q, h, z, max_iter=50, tol=1e-8):
    """Quantile function via damped Newton iteration on CDF(x) - q = 0.

    Parameters
    ----------
    q : array_like
        Probabilities at which to evaluate the quantile function.
    h, z : tensors
        Distribution parameters.
    max_iter : int
        Maximum number of Newton steps taken by the scan; the scan exits
        early once every entry has converged.
    tol : float
        Absolute convergence tolerance on successive iterates.
    """
    # Use log-normal approximation as initial guess to avoid Newton oscillation.
    # PG(h, z) is positive and right-skewed; a log-normal is a good proxy.
    m = mean(h, z)
    v = var(h, z)
    sigma_ln_sq = pt.log1p(v / m**2)
    mu_ln = pt.log(m) - sigma_ln_sq / 2
    sigma_ln = pt.sqrt(sigma_ln_sq)
    # Log-normal quantile: exp(mu + sigma * Phi^{-1}(q)), with the normal
    # quantile expressed through erfinv.
    x0 = pt.exp(mu_ln + sigma_ln * pt.sqrt(2.0) * pt.erfinv(2 * q - 1))
    x0 = pt.maximum(x0, _lower_bound())

    lb = _lower_bound()

    def step(x_prev):
        # scan carries a leading axis added by shape_padleft; strip it so
        # cdf/pdf see the caller's original shape.
        x_prev_squeezed = pt.squeeze(x_prev)

        cdf_val = cdf(x_prev_squeezed, h, z)
        # Floor the derivative so a near-zero density cannot blow up the step.
        f_x = pt.maximum(pdf(x_prev_squeezed, h, z), 1e-10)
        delta = (cdf_val - q) / f_x

        # Damping: cap each Newton step at max(|x|, 0.5) to prevent overshoot.
        max_step = pt.maximum(pt.abs(x_prev_squeezed), 0.5)
        delta = pt.clip(delta, -max_step, max_step)
        x_new = pt.maximum(x_prev_squeezed - delta, lb)

        # Freeze entries that have already converged.
        converged = pt.abs(x_new - x_prev_squeezed) < tol
        x_new = pt.switch(converged, x_prev_squeezed, x_new)

        all_converged = pt.all(converged)
        return pt.shape_padleft(x_new), until(all_converged)

    x_seq = scan(fn=step, outputs_info=pt.shape_padleft(x0), n_steps=max_iter, return_updates=False)

    # Take the final iterate and clamp results to the support bounds.
    return ppf_bounds_cont(x_seq[-1].squeeze(), q, 0, pt.inf)


def isf(q, h, z):
    """Inverse survival function: the (1 - q) quantile."""
    return ppf(1.0 - q, h, z)


def rvs(h, z, size=None, random_state=None):
    """Draw PG(h, z) samples via a truncated infinite gamma convolution.

    Uses the standard series representation
        PG(h, z) = (1 / (2*pi^2)) * sum_{k>=1} g_k / ((k - 1/2)^2 + z^2 / (4*pi^2)),
    with g_k ~ Gamma(h, 1), truncated at K = 200 terms.

    Parameters
    ----------
    h, z : tensors
        Distribution parameters.
    size : int, tuple, or None
        Output sample shape; None draws a single sample.
    random_state
        RNG passed through to ``pt.random.gamma``.
    """
    K = 200
    k = pt.arange(1, K + 1, dtype="float64")

    # Draw all K gamma variates for every sample at once: the series index
    # lives on the leading axis, the requested sample shape on the rest.
    if size is None:
        gamma_size = (K,)
    elif isinstance(size, int):
        gamma_size = (K, size)
    else:
        gamma_size = (K, *size)

    gamma_draws = pt.random.gamma(h, scale=1.0, size=gamma_size, rng=random_state)

    z2_term = z**2 / (4 * pt.pi**2)
    # Broadcast k over the sample axes of gamma_draws.
    k_bc = k.reshape((-1,) + (1,) * (gamma_draws.ndim - 1))
    denom = (k_bc - 0.5) ** 2 + z2_term

    # Collapse the series axis and apply the 1/(2*pi^2) prefactor.
    return pt.sum(gamma_draws / denom, axis=0) / (2 * pt.pi**2)
185 changes: 185 additions & 0 deletions tests/test_polyagamma.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
"""Test PolyaGamma distribution against empirical samples."""

import numpy as np
import pytensor.tensor as pt
import pytest
from numpy.testing import assert_allclose
from scipy.integrate import quad
from scipy.stats import kurtosis, skew

from pytensor_distributions import polyagamma as PolyaGamma
from tests.helper_scipy import make_params

PARAMS_LIST = [
(1.0, 0.0),
(2.0, 2.0),
(0.5, 1.0),
]


@pytest.fixture(scope="module")
def samples():
    """Generate samples once for all tests (shared across parametrizations)."""
    cache = {}
    for h, z in PARAMS_LIST:
        params = make_params(h, z, dtype="float64")
        rng = pt.random.default_rng(1)
        draws = PolyaGamma.rvs(*params, size=10_000, random_state=rng).eval()
        cache[(h, z)] = (params, draws)
    return cache


@pytest.mark.parametrize("params", PARAMS_LIST)
def test_polyagamma_moments(params, samples):
    """Theoretical moments should match empirical moments from samples."""
    p_params, rvs = samples[params]

    # (theoretical fn, empirical value, rtol, atol)
    checks = [
        (PolyaGamma.mean, rvs.mean(), 3e-2, 3e-2),
        (PolyaGamma.var, rvs.var(), 1e-1, 1e-3),
        (PolyaGamma.std, rvs.std(), 1e-1, 1e-3),
        (PolyaGamma.skewness, skew(rvs), 3e-1, 1e-2),
        (PolyaGamma.kurtosis, kurtosis(rvs), 3e-1, 1e-2),
    ]
    for fn, empirical, rtol, atol in checks:
        assert_allclose(fn(*p_params).eval(), empirical, rtol=rtol, atol=atol)


@pytest.mark.parametrize("params", PARAMS_LIST)
def test_polyagamma_cdf(params, samples):
    """CDF should match empirical CDF and be monotonic on a small grid."""
    p_params, rvs = samples[params]

    # Compare the theoretical CDF at 20 sample points to the empirical CDF.
    probe = rvs[:20]
    theory = PolyaGamma.cdf(probe, *p_params).eval()
    for idx, point in enumerate(probe):
        empirical = np.mean(rvs <= point)
        assert_allclose(theory[idx], empirical, rtol=1e-1, atol=1e-3)

    # Monotonicity check on a grid spanning the bulk of the sample mass.
    grid = np.linspace(np.percentile(rvs, 1), np.percentile(rvs, 99), 50)
    values = PolyaGamma.cdf(grid, *p_params).eval()
    assert np.all(np.diff(values) >= -1e-4), "CDF is not monotonic"


@pytest.mark.parametrize("params", PARAMS_LIST)
def test_polyagamma_cdf_bounds(params, samples):
    """CDF should be 0 at lower bound and handle out-of-support."""
    p_params, _ = samples[params]

    # Values at/below zero map to 0; +inf maps to 1.
    probe = np.array([0.0, np.inf, -1.0, -2.0, np.inf, np.inf])
    want = np.array([0.0, 1.0, 0.0, 0.0, 1.0, 1.0])
    assert_allclose(PolyaGamma.cdf(probe, *p_params).eval(), want)


@pytest.mark.slow
@pytest.mark.parametrize("params", [(1.0, 0.0)])
def test_polyagamma_ppf(params, samples):
    """PPF should match empirical quantiles. Slow due to scan-based solver."""
    p_params, rvs = samples[params]

    probs = np.linspace(0.05, 0.95, 10)
    got = PolyaGamma.ppf(probs, *p_params).eval()
    want = np.quantile(rvs, probs)
    assert_allclose(got, want, rtol=1e-1, atol=5e-3)


@pytest.mark.slow
def test_polyagamma_ppf_cdf_inverse(samples):
    """CDF(PPF(q)) should recover q. Slow due to scan-based solver."""
    p_params, _ = samples[(1.0, 0.0)]

    probs = np.array([0.1, 0.5, 0.9])
    quantiles = PolyaGamma.ppf(probs, *p_params).eval()
    roundtrip = PolyaGamma.cdf(quantiles, *p_params).eval()
    assert_allclose(roundtrip, probs, atol=1e-4)


@pytest.mark.parametrize("params", PARAMS_LIST)
def test_polyagamma_pdf(params, samples):
    """PDF should be non-negative and integrate to ~1."""
    p_params, rvs = samples[params]

    grid = np.linspace(0.01, np.percentile(rvs, 99), 100)
    density = PolyaGamma.pdf(grid, *p_params).eval()
    assert np.all(density >= 0), "PDF has negative values"

    # Integrate up to the 99.9th sample percentile; the tail mass beyond
    # that is within the 0.02 tolerance.
    upper = float(np.percentile(rvs, 99.9))
    result, _ = quad(lambda x: PolyaGamma.pdf(x, *p_params).eval(), 0, upper)
    assert np.abs(result - 1) < 0.02, f"PDF integral = {result}, should be 1"


@pytest.mark.parametrize("params", PARAMS_LIST)
def test_polyagamma_pdf_cdf_consistency(params, samples):
    """Numerical derivative of CDF should approximate PDF."""
    p_params, rvs = samples[params]

    grid = np.linspace(np.percentile(rvs, 5), np.percentile(rvs, 95), 20)
    eps = 1e-5
    # Central finite difference of the CDF approximates the density.
    upper = PolyaGamma.cdf(grid + eps, *p_params).eval()
    lower = PolyaGamma.cdf(grid - eps, *p_params).eval()
    numerical_pdf = (upper - lower) / (2 * eps)
    pdf_vals = PolyaGamma.pdf(grid, *p_params).eval()

    # Only compare where the density is non-negligible.
    mask = np.abs(pdf_vals) > 1e-4
    if np.any(mask):
        rel_error = np.abs(numerical_pdf[mask] - pdf_vals[mask]) / (np.abs(pdf_vals[mask]) + 1e-10)
        assert np.all(rel_error < 1e-2), (
            f"PDF doesn't match CDF derivative. Max rel error: {np.max(rel_error)}"
        )


@pytest.mark.parametrize("params", PARAMS_LIST)
def test_polyagamma_sf_complement(params, samples):
    """SF + CDF should equal 1."""
    p_params, rvs = samples[params]

    probe = rvs[:20]
    lower_tail = PolyaGamma.cdf(probe, *p_params).eval()
    upper_tail = PolyaGamma.sf(probe, *p_params).eval()
    assert_allclose(lower_tail + upper_tail, 1.0, atol=1e-4)


@pytest.mark.parametrize("params", PARAMS_LIST)
def test_polyagamma_entropy(params, samples):
    """Monte Carlo entropy should match computed entropy."""
    p_params, rvs = samples[params]

    # MC estimate: -E[log f(X)] over the samples, dropping non-finite values.
    log_density = PolyaGamma.logpdf(rvs, *p_params).eval()
    log_density = log_density[np.isfinite(log_density)]
    mc_entropy = -np.mean(log_density)
    computed_entropy = PolyaGamma.entropy(*p_params).eval()

    rel_error = np.abs(mc_entropy - computed_entropy) / (np.abs(computed_entropy) + 1e-10)
    assert rel_error < 0.1, f"Entropy mismatch. MC: {mc_entropy}, Computed: {computed_entropy}"


def test_polyagamma_mean_z_zero():
    """Mean at z=0 should equal h/4."""
    p_params = make_params(2.0, 0.0, dtype="float64")
    # h = 2, so the z -> 0 limit h/4 equals 0.5.
    assert_allclose(PolyaGamma.mean(*p_params).eval(), 0.5, rtol=1e-10)


def test_polyagamma_var_z_zero():
    """Variance at z=0 should equal h/24."""
    p_params = make_params(2.0, 0.0, dtype="float64")
    # h = 2, so the z -> 0 limit h/24 equals 2/24.
    assert_allclose(PolyaGamma.var(*p_params).eval(), 2.0 / 24, rtol=1e-10)


def test_polyagamma_pdf_positive():
    """PDF should be positive on the support."""
    p_params = make_params(1.0, 1.0, dtype="float64")
    grid = np.linspace(0.01, 2.0, 100)
    density = PolyaGamma.pdf(grid, *p_params).eval()
    assert np.all(density > 0)


def test_polyagamma_pdf_zero_outside():
    """PDF should be zero for x <= 0."""
    p_params = make_params(1.0, 1.0, dtype="float64")
    for x in (-1.0, 0.0):
        assert_allclose(PolyaGamma.pdf(x, *p_params).eval(), 0.0)


def test_polyagamma_logpdf_neginf_outside():
    """Logpdf should be -inf for x <= 0."""
    p_params = make_params(1.0, 1.0, dtype="float64")
    for x in (-1.0, 0.0):
        assert PolyaGamma.logpdf(x, *p_params).eval() == -np.inf
Loading