Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 77 additions & 0 deletions distributions/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,83 @@ def discrete_kurtosis(ppf, pdf, *params):
return result


def continuous_moment(lower, upper, logpdf, *params, order=1, mean_val=None, n_points=1000):
    """
    Compute raw or central moments for continuous distributions using numerical integration.

    Uses the composite trapezoidal rule on ``n_points`` equally spaced nodes spanning
    ``[lower, upper]``. The density is used as returned by ``logpdf`` (no renormalisation),
    so probability mass outside the interval is not accounted for.

    Parameters
    ----------
    lower : float
        Lower bound for integration
    upper : float
        Upper bound for integration
    logpdf : function
        Log probability density function that takes (x, *params) as arguments
    *params : tensor variables
        Distribution parameters to pass to logpdf. May be empty for a
        parameter-free density.
    order : int
        Order of the moment to compute
    mean_val : tensor, optional
        If provided, computes central moment around this mean.
        If None, computes raw moment.
    n_points : int
        Number of integration points (at least 2)

    Returns
    -------
    moment : tensor

    Raises
    ------
    ValueError
        If ``n_points`` is a Python int smaller than 2; the step size
        ``(upper - lower) / (n_points - 1)`` is undefined in that case.
    """
    if isinstance(n_points, int) and n_points < 2:
        raise ValueError("n_points must be at least 2 for trapezoidal integration")

    # Only the rank of the (broadcast) parameters is needed: it tells us how many
    # trailing singleton axes the integration grid must carry to broadcast against them.
    if not params:
        param_ndim = 0
    elif len(params) == 1:
        param_ndim = pt.as_tensor_variable(params[0]).ndim
    else:
        param_ndim = pt.broadcast_arrays(*params)[0].ndim

    x_vals = pt.linspace(lower, upper, n_points)
    # Grid lives on axis 0; parameters broadcast along the remaining axes.
    x_broadcast = x_vals.reshape((-1,) + (1,) * param_ndim)
    pdf_vals = pt.exp(logpdf(x_broadcast, *params))

    if mean_val is not None:
        # Central moment
        integrand = (x_broadcast - mean_val) ** order * pdf_vals
    else:
        # Raw moment
        integrand = x_broadcast**order * pdf_vals

    # Trapezoidal rule: end nodes get half weight, interior nodes full weight.
    dx = (upper - lower) / (n_points - 1)
    result = dx * (0.5 * integrand[0] + pt.sum(integrand[1:-1], axis=0) + 0.5 * integrand[-1])

    return pt.squeeze(result) if param_ndim == 0 else result


def continuous_mean(lower, upper, logpdf, *params):
    """Return the first raw moment of a continuous density integrated over [lower, upper]."""
    first_raw_moment = continuous_moment(lower, upper, logpdf, *params, order=1)
    return first_raw_moment


def continuous_variance(lower, upper, logpdf, *params):
    """Return the second central moment of a continuous density over [lower, upper]."""
    integration_args = (lower, upper, logpdf, *params)
    # The mean is computed first so the second moment can be centred on it.
    center = continuous_moment(*integration_args, order=1)
    return continuous_moment(*integration_args, order=2, mean_val=center)


def continuous_skewness(lower, upper, logpdf, *params):
    """Return the skewness: third central moment divided by the cubed standard deviation."""
    center = continuous_moment(lower, upper, logpdf, *params, order=1)

    def _central(k):
        # k-th moment about the numerically integrated mean.
        return continuous_moment(lower, upper, logpdf, *params, order=k, mean_val=center)

    second = _central(2)
    third = _central(3)
    std_cubed = pt.sqrt(second) ** 3
    return third / std_cubed


def continuous_kurtosis(lower, upper, logpdf, *params):
    """Return the excess kurtosis: fourth central moment over squared variance, minus 3."""
    center = continuous_moment(lower, upper, logpdf, *params, order=1)

    def _central(k):
        # k-th moment about the numerically integrated mean.
        return continuous_moment(lower, upper, logpdf, *params, order=k, mean_val=center)

    second = _central(2)
    fourth = _central(4)
    return fourth / (second**2) - 3


def from_tau(tau):
"""Convert precision (tau) to standard deviation (sigma)."""
sigma = 1 / pt.sqrt(tau)
Expand Down
132 changes: 132 additions & 0 deletions distributions/logitnormal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
import pytensor.tensor as pt

from distributions.helper import (
cdf_bounds,
continuous_entropy,
continuous_kurtosis,
continuous_mean,
continuous_skewness,
continuous_variance,
ppf_bounds_cont,
)
from distributions.normal import ppf as normal_ppf

# Bounds handed to the numerical helpers (continuous_mean, continuous_variance, ...)
# in this module. The logitnormal support is the open interval (0, 1); these
# values sit slightly inside it, so mass outside [_LOWER, _UPPER] is not integrated.
_LOWER = 0.001
_UPPER = 0.999


def _logit(x):
    """Log-odds transform, evaluated as ``log(x) - log1p(-x)``."""
    log_x = pt.log(x)
    log_one_minus_x = pt.log1p(-x)
    return log_x - log_one_minus_x


def _expit(y):
    """Logistic transform of ``y``; a thin alias for ``pt.sigmoid``."""
    squashed = pt.sigmoid(y)
    return squashed


def mean(mu, sigma):
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In helper.py, we have functions to numerically compute moments for discrete distributions. We should move this and other methods there, so we can reuse them for other distributions. In preliz, distributions has an xvals method https://github.com/arviz-devs/preliz/blob/28bbd018963cbc010d3f13e62124eb4653ec1459/preliz/distributions/distributions.py#L507 that we used for plotting or, in this case, to get a reasonable range of values to evaluate some functions. We could have something similar here.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I ported a few things over to helper.py. Let me know what you think.

return continuous_mean(_LOWER, _UPPER, logpdf, mu, sigma)


def mode(mu, sigma):
    """
    Numerically locate the mode of the logit-normal density.

    Setting the derivative of the log-density in ``logpdf`` to zero gives the
    stationary-point condition ``logit(x) = mu + sigma**2 * (2 * x - 1)``, which
    has no closed-form solution. ``expit(mu)`` is the median and is not the mode
    in general, so the mode is found in two stages:

    1. a grid search of ``logpdf`` over ``[_LOWER, _UPPER]`` picks the node with the
       highest density (this selects the dominant peak when the density is bimodal);
    2. a fixed-point iteration ``x <- expit(mu + sigma**2 * (2 * x - 1))`` refines it.
       The map is increasing and its stable fixed points are exactly the local
       maxima of the density, so each step moves monotonically towards the peak
       nearest to the grid estimate.

    Parameters
    ----------
    mu : tensor
        Location of the underlying normal distribution.
    sigma : tensor
        Scale of the underlying normal distribution.

    Returns
    -------
    mode : tensor
        Tensor with the broadcast shape of ``mu`` and ``sigma``. For a symmetric
        bimodal density the lower of the two equal peaks is returned (first
        maximum found on the grid).
    """
    mu_b, sigma_b = pt.broadcast_arrays(mu, sigma)

    # Stage 1: coarse global search. The grid sits on axis 0 and broadcasts
    # against the parameters on the remaining axes.
    grid = pt.linspace(_LOWER, _UPPER, 1000)
    grid_b = grid.reshape((-1,) + (1,) * mu_b.ndim)
    x = grid[pt.argmax(logpdf(grid_b, mu_b, sigma_b), axis=0)]

    # Stage 2: polish with the stationary-point fixed-point map.
    for _ in range(20):
        x = _expit(mu_b + sigma_b**2 * (2 * x - 1))
    return x


def median(mu, sigma):
    """Return ``expit(mu)`` filled into the broadcast shape of ``mu`` and ``sigma``."""
    # First element of the broadcast pair serves as the shape/dtype template.
    template = pt.broadcast_arrays(mu, sigma)[0]
    value = _expit(mu)
    return pt.full_like(template, value)


def var(mu, sigma):
    """Variance, obtained from ``continuous_variance`` over [_LOWER, _UPPER]."""
    variance = continuous_variance(_LOWER, _UPPER, logpdf, mu, sigma)
    return variance


def std(mu, sigma):
    """Standard deviation: square root of ``var(mu, sigma)``."""
    variance = var(mu, sigma)
    return pt.sqrt(variance)


def skewness(mu, sigma):
    """Skewness, obtained from ``continuous_skewness`` over [_LOWER, _UPPER]."""
    skew = continuous_skewness(_LOWER, _UPPER, logpdf, mu, sigma)
    return skew


def kurtosis(mu, sigma):
    """Kurtosis as returned by ``continuous_kurtosis`` over [_LOWER, _UPPER]."""
    kurt = continuous_kurtosis(_LOWER, _UPPER, logpdf, mu, sigma)
    return kurt


def entropy(mu, sigma):
    """Entropy, delegated to ``continuous_entropy`` with bounds [_LOWER, _UPPER]."""
    value = continuous_entropy(_LOWER, _UPPER, logpdf, mu, sigma)
    return value


def pdf(x, mu, sigma):
    """Density at ``x``, computed by exponentiating ``logpdf``."""
    log_density = logpdf(x, mu, sigma)
    return pt.exp(log_density)


def logpdf(x, mu, sigma):
    """Log-density at ``x``; ``-inf`` whenever ``x <= 0`` or ``x >= 1``."""
    standardized = (_logit(x) - mu) / sigma
    # Normal log-density of logit(x) plus the log-Jacobian term -log(x) - log(1 - x).
    inside_support = (
        -0.5 * standardized**2
        - pt.log(sigma)
        - 0.5 * pt.log(2 * pt.pi)
        - pt.log(x)
        - pt.log1p(-x)
    )
    outside = pt.or_(pt.le(x, 0), pt.ge(x, 1))
    return pt.switch(outside, -pt.inf, inside_support)


def cdf(x, mu, sigma):
    """Cumulative probability at ``x``, passed through ``cdf_bounds`` with limits 0 and 1."""
    erf_arg = (_logit(x) - mu) / (sigma * pt.sqrt(2))
    prob = 0.5 * (1 + pt.erf(erf_arg))
    return cdf_bounds(prob, x, 0, 1)


def logcdf(x, mu, sigma):
    """Log of the CDF at ``x``: ``-inf`` for ``x <= 0``, ``0`` for ``x >= 1``."""
    z = (_logit(x) - mu) / sigma

    # Far left tail (z < -1): scaled complementary error function form.
    left_tail = pt.log(pt.erfcx(-z / pt.sqrt(2.0)) / 2.0) - pt.sqr(z) / 2.0
    # Everywhere else: log1p of minus the upper-tail probability.
    central = pt.log1p(-pt.erfc(z / pt.sqrt(2.0)) / 2.0)
    interior = pt.switch(pt.lt(z, -1.0), left_tail, central)

    at_or_above_one = pt.switch(pt.ge(x, 1), 0.0, interior)
    return pt.switch(pt.le(x, 0), -pt.inf, at_or_above_one)


def sf(x, mu, sigma):
    """Survival function at ``x``, computed by exponentiating ``logsf``."""
    log_survival = logsf(x, mu, sigma)
    return pt.exp(log_survival)


def logsf(x, mu, sigma):
    """
    Log of the survival function ``log(1 - cdf(x))``.

    Returns ``0`` for ``x <= 0`` and ``-inf`` for ``x >= 1``. Inside the support
    the value is computed from ``z = (logit(x) - mu) / sigma`` with two branches,
    mirroring ``logcdf``:

    * ``z > 1``: ``log(erfcx(z / sqrt(2)) / 2) - z**2 / 2``, which stays finite in
      the far right tail;
    * otherwise: ``log1p(-erfc(-z / sqrt(2)) / 2)``.

    The second branch writes the normal CDF as ``erfc(-z / sqrt(2)) / 2`` rather
    than ``0.5 * (1 + erf(z / sqrt(2)))``. The two are mathematically equal, but
    ``1 + erf(...)`` cancels to exactly zero for very negative ``z``, which would
    make the result collapse to ``0`` instead of a small negative number.
    """
    z = (_logit(x) - mu) / sigma

    right_tail = pt.log(pt.erfcx(z / pt.sqrt(2.0)) / 2.0) - pt.sqr(z) / 2.0
    central = pt.log1p(-pt.erfc(-z / pt.sqrt(2.0)) / 2.0)
    interior = pt.switch(pt.gt(z, 1.0), right_tail, central)

    return pt.switch(
        pt.le(x, 0),
        0.0,
        pt.switch(pt.ge(x, 1), -pt.inf, interior),
    )


def ppf(q, mu, sigma):
    """Quantile function: ``expit`` of the normal quantile, passed through ``ppf_bounds_cont``."""
    normal_quantile = normal_ppf(q, mu, sigma)
    return ppf_bounds_cont(_expit(normal_quantile), q, 0, 1)


def isf(q, mu, sigma):
    """Inverse survival function, evaluated as ``ppf(1 - q, mu, sigma)``."""
    lower_tail_prob = 1 - q
    return ppf(lower_tail_prob, mu, sigma)


def rvs(mu, sigma, size=None, random_state=None):
    """Draw samples by applying ``expit`` to normal draws with parameters ``mu``, ``sigma``."""
    normal_draws = pt.random.normal(mu, sigma, rng=random_state, size=size)
    return _expit(normal_draws)
39 changes: 39 additions & 0 deletions tests/test_logitnormal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
"""Test Logit-Normal distribution against empirical samples."""

import pytest

from distributions import logitnormal as LogitNormal
from tests.helper_empirical import run_empirical_tests
from tests.helper_scipy import make_params


@pytest.mark.parametrize(
    "params",
    [
        [0.0, 1.0],  # Standard logit-normal (centered)
        [0.0, 0.001],  # Narrower distribution
        [1.0, 1.0],  # Shifted right (mode > 0.5)
        [-1.0, 1.0],  # Shifted left (mode < 0.5)
        [0.0, 2.0],  # Wider distribution (approaches U-shape)
        [2.0, 0.5],  # Strongly shifted right
    ],
)
def test_logitnormal_vs_random(params):
    """Check the Logit-Normal functions against statistics of random samples."""
    # Relative tolerances forwarded to the empirical test runner.
    tolerances = {
        "mean_rtol": 1e-2,
        "var_rtol": 1e-2,
        "std_rtol": 1e-2,
        "skewness_rtol": 2e-1,
        "kurtosis_rtol": 2e-1,
        "quantiles_rtol": 3e-2,
        "cdf_rtol": 5e-2,
    }

    run_empirical_tests(
        p_dist=LogitNormal,
        p_params=make_params(*params, dtype="float64"),
        support=(0, 1),
        name="logitnormal",
        sample_size=500_000,
        **tolerances,
    )