From c80e4a83b3874a2e067ffdd9c13c47b4e32874e0 Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Fri, 6 Feb 2026 09:17:05 +0200 Subject: [PATCH 1/3] add mvnormal --- distributions/mvnormal.py | 86 +++++++++++++++++++++++++++ tests/test_mvnormal.py | 120 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 206 insertions(+) create mode 100644 distributions/mvnormal.py create mode 100644 tests/test_mvnormal.py diff --git a/distributions/mvnormal.py b/distributions/mvnormal.py new file mode 100644 index 0000000..b3ccc2d --- /dev/null +++ b/distributions/mvnormal.py @@ -0,0 +1,86 @@ +from functools import partial + +import numpy as np +import pytensor.tensor as pt +from pytensor.tensor.linalg import solve_triangular + +solve_lower = partial(solve_triangular, lower=True) + + +def _logdet_from_cholesky(chol): + """Compute log determinant from Cholesky factor and check positive definiteness.""" + diag = pt.diagonal(chol, axis1=-2, axis2=-1) + logdet = pt.sum(pt.log(diag), axis=-1) * 2 + posdef = pt.all(diag > 0, axis=-1) + return logdet, posdef + + +def quaddist_chol(value, mu, cov): + """Compute (x - mu).T @ Sigma^-1 @ (x - mu) and the logdet of Sigma.""" + if value.ndim == 0: + raise ValueError("Value can't be a scalar") + if value.ndim == 1: + onedim = True + value = value[None, :] + else: + onedim = False + + chol_cov = pt.linalg.cholesky(cov, lower=True) + logdet, posdef = _logdet_from_cholesky(chol_cov) + + delta = value - mu + delta_trans = solve_lower(chol_cov, delta, b_ndim=1) + quaddist = (delta_trans**2).sum(axis=-1) + + if onedim: + return quaddist[0], logdet, posdef + else: + return quaddist, logdet, posdef + + +def mean(mu, cov): + return mu + + +def mode(mu, cov): + return mu + + +def median(mu, cov): + return mu + + +def var(mu, cov): + return pt.diagonal(cov, axis1=-2, axis2=-1) + + +def std(mu, cov): + return pt.sqrt(var(mu, cov)) + + +def skewness(mu, cov): + return pt.zeros_like(mu) + + +def kurtosis(mu, cov): + return pt.zeros_like(mu) + + +def entropy(mu, cov): + k = cov.shape[-1] + _, logdet = pt.linalg.slogdet(cov) + return 0.5 * (k * pt.log(2 * pt.pi * pt.e) + logdet) + + +def pdf(x, mu, cov): + return pt.exp(logpdf(x, mu, cov)) + + +def logpdf(x, mu, cov): + quaddist, logdet, _ = quaddist_chol(x, mu, cov) + k = pt.as_tensor(x.shape[-1], dtype="floatX") + return -0.5 * (k * pt.log(2 * np.pi) + logdet + quaddist) + + +def rvs(mu, cov, size=None, random_state=None): + return pt.random.multivariate_normal(mu, cov, size=size, rng=random_state) diff --git a/tests/test_mvnormal.py b/tests/test_mvnormal.py new file mode 100644 index 0000000..88c6aa4 --- /dev/null +++ b/tests/test_mvnormal.py @@ -0,0 +1,120 @@ +"""Test Multivariate Normal distribution.""" + +import numpy as np +import pytensor.tensor as pt +import pytest +from numpy.testing import assert_allclose +from scipy.stats import multivariate_normal + +from distributions import mvnormal as MvNormal + +# Test parameters for multivariate normal +TEST_CASES = [ + (np.array([0.0, 0.0]), np.eye(2)), + (np.array([1.0, -1.0]), np.array([[2.0, 0.0], [0.0, 0.5]])), + (np.array([0.0, 0.0, 0.0]), np.eye(3)), + ( + np.array([1.0, 2.0, 3.0]), + np.array([[1.0, 0.5, 0.2], [0.5, 2.0, -0.3], [0.2, -0.3, 0.8]]), + ), +] + + +@pytest.mark.parametrize("mu, cov", TEST_CASES) +def test_mvnormal_logpdf(mu, cov): + scipy_dist = multivariate_normal(mean=mu, cov=cov) + + p_mu = pt.constant(mu) + p_cov = pt.constant(cov) + + actual = MvNormal.logpdf(mu, p_mu, p_cov).eval() + expected = scipy_dist.logpdf(mu) + assert_allclose(actual, expected, rtol=1e-3, err_msg=f"logpdf at mean failed for mu={mu}") + + x_samples = scipy_dist.rvs(size=10, random_state=814) + actual = MvNormal.logpdf(x_samples, p_mu, p_cov).eval() + expected = scipy_dist.logpdf(x_samples) + assert_allclose(actual, expected, rtol=1e-5, err_msg=f"logpdf at samples failed for mu={mu}") + + +@pytest.mark.parametrize("mu, cov", TEST_CASES) +def test_mvnormal_pdf(mu, cov): + scipy_dist = multivariate_normal(mean=mu, cov=cov) + + p_mu = pt.constant(mu) + p_cov = pt.constant(cov) + + actual = MvNormal.pdf(mu, p_mu, p_cov).eval() + expected = scipy_dist.pdf(mu) + assert_allclose(actual, expected, rtol=1e-5, err_msg="pdf at mean failed") + + +@pytest.mark.parametrize("mu, cov", TEST_CASES) +def test_mvnormal_moments(mu, cov): + p_mu = pt.constant(mu) + p_cov = pt.constant(cov) + + actual = MvNormal.mean(p_mu, p_cov).eval() + assert_allclose(actual, mu, rtol=1e-10, err_msg="Mean should equal mu") + + actual = MvNormal.mode(p_mu, p_cov).eval() + assert_allclose(actual, mu, rtol=1e-10, err_msg="Mode should equal mu") + + actual = MvNormal.median(p_mu, p_cov).eval() + assert_allclose(actual, mu, rtol=1e-10, err_msg="Median should equal mu") + + actual = MvNormal.skewness(p_mu, p_cov).eval() + expected = np.zeros_like(mu) + assert_allclose(actual, expected, atol=1e-10, err_msg="Skewness should be zero") + + actual = MvNormal.kurtosis(p_mu, p_cov).eval() + expected = np.zeros_like(mu) + assert_allclose(actual, expected, atol=1e-10, err_msg="Kurtosis should be zero") + + +@pytest.mark.parametrize("mu, cov", TEST_CASES) +def test_mvnormal_var(mu, cov): + p_mu = pt.constant(mu) + p_cov = pt.constant(cov) + + actual = MvNormal.var(p_mu, p_cov).eval() + expected = np.diagonal(cov) + assert_allclose(actual, expected, rtol=1e-10, err_msg="Variance should equal diagonal of cov") + + +@pytest.mark.parametrize("mu, cov", TEST_CASES) +def test_mvnormal_std(mu, cov): + """Test standard deviation.""" + p_mu = pt.constant(mu) + p_cov = pt.constant(cov) + + actual = MvNormal.std(p_mu, p_cov).eval() + expected = np.sqrt(np.diagonal(cov)) + assert_allclose(actual, expected, rtol=1e-10, err_msg="Std should equal sqrt of diagonal") + + +@pytest.mark.parametrize("mu, cov", TEST_CASES) +def test_mvnormal_entropy(mu, cov): + scipy_dist = multivariate_normal(mean=mu, cov=cov) + + p_mu = pt.constant(mu) + p_cov = pt.constant(cov) + + actual = MvNormal.entropy(p_mu, p_cov).eval() + expected = scipy_dist.entropy() + assert_allclose(actual, expected, rtol=1e-5, err_msg="Entropy test failed") + + +@pytest.mark.parametrize("mu, cov", TEST_CASES) +def test_mvnormal_rvs(mu, cov): + p_mu = pt.constant(mu) + p_cov = pt.constant(cov) + rng = pt.random.default_rng(205) + + samples = MvNormal.rvs(p_mu, p_cov, size=10000, random_state=rng).eval() + + assert samples.shape == (10000, len(mu)), f"Shape mismatch: got {samples.shape}" + assert_allclose(samples.mean(axis=0), mu, atol=0.2, err_msg="Sample mean should be close to mu") + assert_allclose( + np.cov(samples.T), cov, rtol=0.2, atol=0.1, err_msg="Sample cov should be close to cov" + ) From bcde039291d40b8e4897e9e48c70f047ee7baf5a Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Fri, 6 Feb 2026 09:46:38 +0200 Subject: [PATCH 2/3] broadcast --- distributions/mvnormal.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/distributions/mvnormal.py b/distributions/mvnormal.py index b3ccc2d..d36e16f 100644 --- a/distributions/mvnormal.py +++ b/distributions/mvnormal.py @@ -39,15 +39,15 @@ def quaddist_chol(value, mu, cov): def mean(mu, cov): - return mu + return pt.broadcast_to(mu, cov.shape[:-1]) def mode(mu, cov): - return mu + return pt.broadcast_to(mu, cov.shape[:-1]) def median(mu, cov): - return mu + return pt.broadcast_to(mu, cov.shape[:-1]) def var(mu, cov): @@ -59,10 +59,12 @@ def std(mu, cov): def skewness(mu, cov): + mu = pt.broadcast_to(mu, cov.shape[:-1]) return pt.zeros_like(mu) def kurtosis(mu, cov): + mu = pt.broadcast_to(mu, cov.shape[:-1]) return pt.zeros_like(mu) @@ -83,4 +85,5 @@ def logpdf(x, mu, cov): def rvs(mu, cov, size=None, random_state=None): + mu = pt.broadcast_to(mu, cov.shape[:-1]) return pt.random.multivariate_normal(mu, cov, size=size, rng=random_state) From 468d014b47ef77d9f9f7caffa81ba5a917706e4f Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Fri, 6 Feb 2026 10:18:41 +0200 Subject: [PATCH 3/3] drop numpy --- distributions/mvnormal.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/distributions/mvnormal.py b/distributions/mvnormal.py index d36e16f..aa0b343 100644 --- a/distributions/mvnormal.py +++ b/distributions/mvnormal.py @@ -1,6 +1,5 @@ from functools import partial -import numpy as np import pytensor.tensor as pt from pytensor.tensor.linalg import solve_triangular @@ -81,7 +80,7 @@ def pdf(x, mu, cov): def logpdf(x, mu, cov): quaddist, logdet, _ = quaddist_chol(x, mu, cov) k = pt.as_tensor(x.shape[-1], dtype="floatX") - return -0.5 * (k * pt.log(2 * np.pi) + logdet + quaddist) + return -0.5 * (k * pt.log(2 * pt.pi) + logdet + quaddist) def rvs(mu, cov, size=None, random_state=None):