Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 88 additions & 0 deletions distributions/mvnormal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
from functools import partial

import pytensor.tensor as pt
from pytensor.tensor.linalg import solve_triangular

solve_lower = partial(solve_triangular, lower=True)


def _logdet_from_cholesky(chol):
"""Compute log determinant from Cholesky factor and check positive definiteness."""
diag = pt.diagonal(chol, axis1=-2, axis2=-1)
logdet = pt.sum(pt.log(diag), axis=-1) * 2
posdef = pt.all(diag > 0, axis=-1)
return logdet, posdef


def quaddist_chol(value, mu, cov):
"""Compute (x - mu).T @ Sigma^-1 @ (x - mu) and the logdet of Sigma."""
if value.ndim == 0:
raise ValueError("Value can't be a scalar")
if value.ndim == 1:
onedim = True
value = value[None, :]
else:
onedim = False

chol_cov = pt.linalg.cholesky(cov, lower=True)
logdet, posdef = _logdet_from_cholesky(chol_cov)

delta = value - mu
delta_trans = solve_lower(chol_cov, delta, b_ndim=1)
quaddist = (delta_trans**2).sum(axis=-1)

if onedim:
return quaddist[0], logdet, posdef
else:
return quaddist, logdet, posdef


def mean(mu, cov):
return pt.broadcast_to(mu, cov.shape[:-1])


def mode(mu, cov):
return pt.broadcast_to(mu, cov.shape[:-1])


def median(mu, cov):
return pt.broadcast_to(mu, cov.shape[:-1])


def var(mu, cov):
return pt.diagonal(cov, axis1=-2, axis2=-1)


def std(mu, cov):
return pt.sqrt(var(mu, cov))


def skewness(mu, cov):
mu = pt.broadcast_to(mu, cov.shape[:-1])
return pt.zeros_like(mu)


def kurtosis(mu, cov):
mu = pt.broadcast_to(mu, cov.shape[:-1])
return pt.zeros_like(mu)


def entropy(mu, cov):
k = cov.shape[-1]
_, logdet = pt.linalg.slogdet(cov)
return 0.5 * (k * pt.log(2 * pt.pi * pt.e) + logdet)


def pdf(x, mu, cov):
return pt.exp(logpdf(x, mu, cov))


def logpdf(x, mu, cov):
quaddist, logdet, _ = quaddist_chol(x, mu, cov)
k = pt.as_tensor(x.shape[-1], dtype="floatX")
return -0.5 * (k * pt.log(2 * pt.pi) + logdet + quaddist)


def rvs(mu, cov, size=None, random_state=None):
mu = pt.broadcast_to(mu, cov.shape[:-1])
return pt.random.multivariate_normal(mu, cov, size=size, rng=random_state)
120 changes: 120 additions & 0 deletions tests/test_mvnormal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
"""Test Multivariate Normal distribution."""

import numpy as np
import pytensor.tensor as pt
import pytest
from numpy.testing import assert_allclose
from scipy.stats import multivariate_normal

from distributions import mvnormal as MvNormal

# Test parameters for multivariate normal
TEST_CASES = [
(np.array([0.0, 0.0]), np.eye(2)),
(np.array([1.0, -1.0]), np.array([[2.0, 0.0], [0.0, 0.5]])),
(np.array([0.0, 0.0, 0.0]), np.eye(3)),
(
np.array([1.0, 2.0, 3.0]),
np.array([[1.0, 0.5, 0.2], [0.5, 2.0, -0.3], [0.2, -0.3, 0.8]]),
),
]


@pytest.mark.parametrize("mu, cov", TEST_CASES)
def test_mvnormal_logpdf(mu, cov):
scipy_dist = multivariate_normal(mean=mu, cov=cov)

p_mu = pt.constant(mu)
p_cov = pt.constant(cov)

actual = MvNormal.logpdf(mu, p_mu, p_cov).eval()
expected = scipy_dist.logpdf(mu)
assert_allclose(actual, expected, rtol=1e-3, err_msg=f"logpdf at mean failed for mu={mu}")

x_samples = scipy_dist.rvs(size=10, random_state=814)
actual = MvNormal.logpdf(x_samples, p_mu, p_cov).eval()
expected = scipy_dist.logpdf(x_samples)
assert_allclose(actual, expected, rtol=1e-5, err_msg=f"logpdf at samples failed for mu={mu}")


@pytest.mark.parametrize("mu, cov", TEST_CASES)
def test_mvnormal_pdf(mu, cov):
scipy_dist = multivariate_normal(mean=mu, cov=cov)

p_mu = pt.constant(mu)
p_cov = pt.constant(cov)

actual = MvNormal.pdf(mu, p_mu, p_cov).eval()
expected = scipy_dist.pdf(mu)
assert_allclose(actual, expected, rtol=1e-5, err_msg="pdf at mean failed")


@pytest.mark.parametrize("mu, cov", TEST_CASES)
def test_mvnormal_moments(mu, cov):
p_mu = pt.constant(mu)
p_cov = pt.constant(cov)

actual = MvNormal.mean(p_mu, p_cov).eval()
assert_allclose(actual, mu, rtol=1e-10, err_msg="Mean should equal mu")

actual = MvNormal.mode(p_mu, p_cov).eval()
assert_allclose(actual, mu, rtol=1e-10, err_msg="Mode should equal mu")

actual = MvNormal.median(p_mu, p_cov).eval()
assert_allclose(actual, mu, rtol=1e-10, err_msg="Median should equal mu")

actual = MvNormal.skewness(p_mu, p_cov).eval()
expected = np.zeros_like(mu)
assert_allclose(actual, expected, atol=1e-10, err_msg="Skewness should be zero")

actual = MvNormal.kurtosis(p_mu, p_cov).eval()
expected = np.zeros_like(mu)
assert_allclose(actual, expected, atol=1e-10, err_msg="Kurtosis should be zero")


@pytest.mark.parametrize("mu, cov", TEST_CASES)
def test_mvnormal_var(mu, cov):
p_mu = pt.constant(mu)
p_cov = pt.constant(cov)

actual = MvNormal.var(p_mu, p_cov).eval()
expected = np.diagonal(cov)
assert_allclose(actual, expected, rtol=1e-10, err_msg="Variance should equal diagonal of cov")


@pytest.mark.parametrize("mu, cov", TEST_CASES)
def test_mvnormal_std(mu, cov):
"""Test standard deviation."""
p_mu = pt.constant(mu)
p_cov = pt.constant(cov)

actual = MvNormal.std(p_mu, p_cov).eval()
expected = np.sqrt(np.diagonal(cov))
assert_allclose(actual, expected, rtol=1e-10, err_msg="Std should equal sqrt of diagonal")


@pytest.mark.parametrize("mu, cov", TEST_CASES)
def test_mvnormal_entropy(mu, cov):
scipy_dist = multivariate_normal(mean=mu, cov=cov)

p_mu = pt.constant(mu)
p_cov = pt.constant(cov)

actual = MvNormal.entropy(p_mu, p_cov).eval()
expected = scipy_dist.entropy()
assert_allclose(actual, expected, rtol=1e-5, err_msg="Entropy test failed")


@pytest.mark.parametrize("mu, cov", TEST_CASES)
def test_mvnormal_rvs(mu, cov):
p_mu = pt.constant(mu)
p_cov = pt.constant(cov)
rng = pt.random.default_rng(205)

samples = MvNormal.rvs(p_mu, p_cov, size=10000, random_state=rng).eval()

assert samples.shape == (10000, len(mu)), f"Shape mismatch: got {samples.shape}"
assert_allclose(samples.mean(axis=0), mu, atol=0.2, err_msg="Sample mean should be close to mu")
assert_allclose(
np.cov(samples.T), cov, rtol=0.2, atol=0.1, err_msg="Sample cov should be close to cov"
)
Loading