From 20ec72761fe203e987f55278eaf09dfbaa81a998 Mon Sep 17 00:00:00 2001 From: Chris Fonnesbeck Date: Mon, 23 Feb 2026 12:59:17 -0600 Subject: [PATCH 1/3] Draft matrix normal dist --- pytensor_distributions/matrixnormal.py | 106 ++++++++++++++++++++++++ tests/test_matrixnormal.py | 110 +++++++++++++++++++++++++ 2 files changed, 216 insertions(+) create mode 100644 pytensor_distributions/matrixnormal.py create mode 100644 tests/test_matrixnormal.py diff --git a/pytensor_distributions/matrixnormal.py b/pytensor_distributions/matrixnormal.py new file mode 100644 index 0000000..f107cc2 --- /dev/null +++ b/pytensor_distributions/matrixnormal.py @@ -0,0 +1,106 @@ +"""Matrix Normal distribution.""" + +from functools import partial + +import pytensor.tensor as pt +from pytensor.tensor.linalg import solve_triangular + +from pytensor_distributions.mvnormal import _logdet_from_cholesky + +solve_lower = partial(solve_triangular, lower=True) + + +def mean(mu, rowcov, colcov): + return pt.as_tensor(mu) + + +def mode(mu, rowcov, colcov): + return pt.as_tensor(mu) + + +def median(mu, rowcov, colcov): + return pt.as_tensor(mu) + + +def var(mu, rowcov, colcov): + rowcov = pt.as_tensor(rowcov) + colcov = pt.as_tensor(colcov) + row_diag = pt.diagonal(rowcov, axis1=-2, axis2=-1) + col_diag = pt.diagonal(colcov, axis1=-2, axis2=-1) + return pt.outer(row_diag, col_diag) + + +def std(mu, rowcov, colcov): + return pt.sqrt(var(mu, rowcov, colcov)) + + +def skewness(mu, rowcov, colcov): + return pt.zeros_like(pt.as_tensor(mu)) + + +def kurtosis(mu, rowcov, colcov): + return pt.zeros_like(pt.as_tensor(mu)) + + +def entropy(mu, rowcov, colcov): + mu = pt.as_tensor(mu) + rowcov = pt.as_tensor(rowcov) + colcov = pt.as_tensor(colcov) + m = mu.shape[-2] + n = mu.shape[-1] + mn = m * n + _, logdet_U = pt.linalg.slogdet(rowcov) + _, logdet_V = pt.linalg.slogdet(colcov) + return 0.5 * mn * pt.log(2 * pt.pi * pt.e) + 0.5 * n * logdet_U + 0.5 * m * logdet_V + + +def logpdf(X, mu, rowcov, colcov): + X = pt.as_tensor(X) + mu = pt.as_tensor(mu) + rowcov = pt.as_tensor(rowcov) + colcov = pt.as_tensor(colcov) + + m = mu.shape[-2] + n = mu.shape[-1] + mn = m * n + + chol_row = pt.linalg.cholesky(rowcov, lower=True) + chol_col = pt.linalg.cholesky(colcov, lower=True) + + logdet_row, _ = _logdet_from_cholesky(chol_row) + logdet_col, _ = _logdet_from_cholesky(chol_col) + + delta = X - mu + + # Compute tr[V^-1 (X-M)^T U^-1 (X-M)] via Cholesky solves + # Using vec identity: quadform = ||L_U^-1 delta L_V^-T||^2_F + Y = solve_lower(chol_row, delta) # L_U^-1 delta, shape (m, n) + Z = solve_lower(chol_col, Y.T) # L_V^-1 (L_U^-1 delta)^T, shape (n, m) + quadform = pt.sum(Z**2) + + log_norm = 0.5 * mn * pt.log(2 * pt.pi) + 0.5 * n * logdet_row + 0.5 * m * logdet_col + + return -0.5 * quadform - log_norm + + +def pdf(X, mu, rowcov, colcov): + return pt.exp(logpdf(X, mu, rowcov, colcov)) + + +def rvs(mu, rowcov, colcov, size=None, random_state=None): + mu = pt.as_tensor(mu) + rowcov = pt.as_tensor(rowcov) + colcov = pt.as_tensor(colcov) + + m = mu.shape[-2] + n = mu.shape[-1] + + L_row = pt.linalg.cholesky(rowcov, lower=True) + L_col = pt.linalg.cholesky(colcov, lower=True) + + if size is None: + Z = pt.random.normal(0, 1, size=(m, n), rng=random_state) + return mu + L_row @ Z @ L_col.T + else: + Z = pt.random.normal(0, 1, size=(size, m, n), rng=random_state) + return mu + L_row @ Z @ L_col.T diff --git a/tests/test_matrixnormal.py b/tests/test_matrixnormal.py new file mode 100644 index 0000000..b20cb3c --- /dev/null +++ b/tests/test_matrixnormal.py @@ -0,0 +1,110 @@ +"""Test Matrix Normal distribution.""" + +import numpy as np +import pytensor.tensor as pt +import pytest +from numpy.testing import assert_allclose +from scipy.stats import matrix_normal as scipy_matrix_normal + +from pytensor_distributions import matrixnormal as MatrixNormal + +# Test cases: (M, U, V) +TEST_CASES = [ + # 2x2 with identity covariances + ( + np.array([[1.0, 2.0], [3.0, 4.0]]), + np.eye(2), + np.eye(2), + ), + # 2x3 with non-trivial covariances + ( + np.array([[1.0, 0.0, -1.0], [2.0, 1.0, 0.0]]), + np.array([[2.0, 0.5], [0.5, 1.0]]), + np.array([[1.0, 0.3, 0.1], [0.3, 2.0, 0.4], [0.1, 0.4, 1.5]]), + ), + # 3x2 with correlated covariances + ( + np.zeros((3, 2)), + np.array([[3.0, 1.0, 0.5], [1.0, 2.0, 0.3], [0.5, 0.3, 1.5]]), + np.array([[1.0, 0.7], [0.7, 2.0]]), + ), +] + + +@pytest.mark.parametrize("M, U, V", TEST_CASES) +def test_matrixnormal_mean(M, U, V): + actual = MatrixNormal.mean(pt.constant(M), pt.constant(U), pt.constant(V)).eval() + assert_allclose(actual, M, rtol=1e-10) + + +@pytest.mark.parametrize("M, U, V", TEST_CASES) +def test_matrixnormal_mode(M, U, V): + actual = MatrixNormal.mode(pt.constant(M), pt.constant(U), pt.constant(V)).eval() + assert_allclose(actual, M, rtol=1e-10) + + +@pytest.mark.parametrize("M, U, V", TEST_CASES) +def test_matrixnormal_logpdf(M, U, V): + scipy_dist = scipy_matrix_normal(mean=M, rowcov=U, colcov=V) + X = M + 0.1 * np.ones_like(M) + + actual = MatrixNormal.logpdf( + pt.constant(X), pt.constant(M), pt.constant(U), pt.constant(V) + ).eval() + expected = scipy_dist.logpdf(X) + assert_allclose(actual, expected, rtol=1e-5, err_msg="logpdf should match scipy") + + +@pytest.mark.parametrize("M, U, V", TEST_CASES) +def test_matrixnormal_pdf(M, U, V): + scipy_dist = scipy_matrix_normal(mean=M, rowcov=U, colcov=V) + X = M + 0.1 * np.ones_like(M) + + actual = MatrixNormal.pdf(pt.constant(X), pt.constant(M), pt.constant(U), pt.constant(V)).eval() + expected = scipy_dist.pdf(X) + assert_allclose(actual, expected, rtol=1e-5, err_msg="pdf should match scipy") + + +@pytest.mark.parametrize("M, U, V", TEST_CASES) +def test_matrixnormal_entropy(M, U, V): + scipy_dist = scipy_matrix_normal(mean=M, rowcov=U, colcov=V) + + actual = MatrixNormal.entropy(pt.constant(M), pt.constant(U), pt.constant(V)).eval() + expected = scipy_dist.entropy() + assert_allclose(actual, expected, rtol=1e-5, err_msg="entropy should match scipy") + + +@pytest.mark.parametrize("M, U, V", TEST_CASES) +def test_matrixnormal_var(M, U, V): + actual = MatrixNormal.var(pt.constant(M), pt.constant(U), pt.constant(V)).eval() + expected = np.outer(np.diag(U), np.diag(V)) + assert_allclose(actual, expected, rtol=1e-10, err_msg="var should be outer product of diags") + + +@pytest.mark.parametrize("M, U, V", TEST_CASES) +def test_matrixnormal_rvs(M, U, V): + m, n = M.shape + + sample = MatrixNormal.rvs(pt.constant(M), pt.constant(U), pt.constant(V), size=None).eval() + assert sample.shape == (m, n), f"Single sample should have shape ({m}, {n})" + + n_samples = 2000 + samples = MatrixNormal.rvs( + pt.constant(M), pt.constant(U), pt.constant(V), size=n_samples + ).eval() + assert samples.shape == (n_samples, m, n) + + sample_mean = np.mean(samples, axis=0) + assert_allclose( + sample_mean, M, atol=0.2, err_msg="Sample mean should approximate theoretical mean" + ) + + sample_var = np.var(samples, axis=0) + theoretical_var = np.outer(np.diag(U), np.diag(V)) + assert_allclose( + sample_var, + theoretical_var, + rtol=0.3, + atol=0.2, + err_msg="Sample variance should approximate theoretical variance", + ) From 68c2b1feeb61526aa23107a0a384fdfe22e661c0 Mon Sep 17 00:00:00 2001 From: Chris Fonnesbeck Date: Mon, 23 Feb 2026 13:21:14 -0600 Subject: [PATCH 2/3] Cleanup --- pytensor_distributions/matrixnormal.py | 50 ++++++++---- tests/test_matrixnormal.py | 102 +++++++++++++++++++++++++ 2 files changed, 136 insertions(+), 16 deletions(-) diff --git a/pytensor_distributions/matrixnormal.py b/pytensor_distributions/matrixnormal.py index f107cc2..dacc394 100644 --- a/pytensor_distributions/matrixnormal.py +++ b/pytensor_distributions/matrixnormal.py @@ -10,16 +10,27 @@ solve_lower = partial(solve_triangular, lower=True) +def _broadcast_mu(mu, rowcov, colcov): + """Broadcast mu to the output shape implied by all parameters.""" + mu = pt.as_tensor(mu) + rowcov = pt.as_tensor(rowcov) + colcov = pt.as_tensor(colcov) + # Use zero-addition to trigger automatic broadcasting across batch dims + row_zeros = pt.zeros(rowcov.shape[:-1]) # (..., m) + col_zeros = pt.zeros(colcov.shape[:-1]) # (..., n) + return mu + row_zeros[..., :, None] * col_zeros[..., None, :] + + def mean(mu, rowcov, colcov): - return pt.as_tensor(mu) + return _broadcast_mu(mu, rowcov, colcov) def mode(mu, rowcov, colcov): - return pt.as_tensor(mu) + return _broadcast_mu(mu, rowcov, colcov) def median(mu, rowcov, colcov): - return pt.as_tensor(mu) + return _broadcast_mu(mu, rowcov, colcov) def var(mu, rowcov, colcov): @@ -27,7 +38,7 @@ def var(mu, rowcov, colcov): colcov = pt.as_tensor(colcov) row_diag = pt.diagonal(rowcov, axis1=-2, axis2=-1) col_diag = pt.diagonal(colcov, axis1=-2, axis2=-1) - return pt.outer(row_diag, col_diag) + return row_diag[..., :, None] * col_diag[..., None, :] def std(mu, rowcov, colcov): @@ -35,11 +46,11 @@ def std(mu, rowcov, colcov): def skewness(mu, rowcov, colcov): - return pt.zeros_like(pt.as_tensor(mu)) + return pt.zeros_like(_broadcast_mu(mu, rowcov, colcov)) def kurtosis(mu, rowcov, colcov): - return pt.zeros_like(pt.as_tensor(mu)) + return pt.zeros_like(_broadcast_mu(mu, rowcov, colcov)) def entropy(mu, rowcov, colcov): @@ -74,9 +85,9 @@ def logpdf(X, mu, rowcov, colcov): # Compute tr[V^-1 (X-M)^T U^-1 (X-M)] via Cholesky solves # Using vec identity: quadform = ||L_U^-1 delta L_V^-T||^2_F - Y = solve_lower(chol_row, delta) # L_U^-1 delta, shape (m, n) - Z = solve_lower(chol_col, Y.T) # L_V^-1 (L_U^-1 delta)^T, shape (n, m) - quadform = pt.sum(Z**2) + Y = solve_lower(chol_row, delta) # L_U^-1 delta, shape (..., m, n) + Z = solve_lower(chol_col, pt.swapaxes(Y, -1, -2)) # L_V^-1 (L_U^-1 delta)^T, shape (..., n, m) + quadform = pt.sum(Z**2, axis=(-2, -1)) log_norm = 0.5 * mn * pt.log(2 * pt.pi) + 0.5 * n * logdet_row + 0.5 * m * logdet_col @@ -92,15 +103,22 @@ def rvs(mu, rowcov, colcov, size=None, random_state=None): rowcov = pt.as_tensor(rowcov) colcov = pt.as_tensor(colcov) - m = mu.shape[-2] - n = mu.shape[-1] - L_row = pt.linalg.cholesky(rowcov, lower=True) L_col = pt.linalg.cholesky(colcov, lower=True) if size is None: - Z = pt.random.normal(0, 1, size=(m, n), rng=random_state) - return mu + L_row @ Z @ L_col.T + size = () + elif not isinstance(size, tuple): + size = (size,) + + # Get the broadcast output shape from parameters + target = _broadcast_mu(mu, rowcov, colcov) # (..., m, n) + base_shape = target.shape # symbolic shape vector + + if size: + full_shape = pt.concatenate([pt.as_tensor(size), base_shape]) else: - Z = pt.random.normal(0, 1, size=(size, m, n), rng=random_state) - return mu + L_row @ Z @ L_col.T + full_shape = base_shape + + Z = pt.random.normal(0, 1, size=full_shape, rng=random_state) + return target + L_row @ Z @ pt.swapaxes(L_col, -1, -2) diff --git a/tests/test_matrixnormal.py b/tests/test_matrixnormal.py index b20cb3c..3746f63 100644 --- a/tests/test_matrixnormal.py +++ b/tests/test_matrixnormal.py @@ -108,3 +108,105 @@ def test_matrixnormal_rvs(M, U, V): atol=0.2, err_msg="Sample variance should approximate theoretical variance", ) + + +# --- Batched tests --- + +# Reuse a single test case for batched tests +M0 = np.array([[1.0, 2.0], [3.0, 4.0]]) +U0 = np.array([[2.0, 0.5], [0.5, 1.0]]) +V0 = np.eye(2) + + +class TestBatchedLogpdf: + """Test logpdf with batched observations.""" + + def test_batch_of_observations(self): + Xs = np.array([M0 + 0.1, M0 - 0.1, M0 + 0.5]) # (3, 2, 2) + result = MatrixNormal.logpdf( + pt.constant(Xs), pt.constant(M0), pt.constant(U0), pt.constant(V0) + ) + actual = result.eval() + + scipy_dist = scipy_matrix_normal(mean=M0, rowcov=U0, colcov=V0) + expected = np.array([scipy_dist.logpdf(x) for x in Xs]) + + assert actual.shape == (3,) + assert_allclose(actual, expected, rtol=1e-5) + + def test_single_observation_still_scalar(self): + X = M0 + 0.1 + result = MatrixNormal.logpdf( + pt.constant(X), pt.constant(M0), pt.constant(U0), pt.constant(V0) + ) + actual = result.eval() + expected = scipy_matrix_normal(mean=M0, rowcov=U0, colcov=V0).logpdf(X) + assert actual.shape == () + assert_allclose(actual, expected, rtol=1e-5) + + +class TestBatchedVar: + """Test var with batched covariance matrices.""" + + def test_batched_rowcov(self): + rowcovs = np.array([np.eye(2), 2 * np.eye(2)]) # (2, 2, 2) + result = MatrixNormal.var(pt.constant(M0), pt.constant(rowcovs), pt.constant(V0)) + actual = result.eval() + assert actual.shape == (2, 2, 2) + for i in range(2): + expected = np.outer(np.diag(rowcovs[i]), np.diag(V0)) + assert_allclose(actual[i], expected, rtol=1e-10) + + def test_unbatched_still_works(self): + result = MatrixNormal.var(pt.constant(M0), pt.constant(U0), pt.constant(V0)) + actual = result.eval() + expected = np.outer(np.diag(U0), np.diag(V0)) + assert_allclose(actual, expected, rtol=1e-10) + + +class TestBatchedRvs: + """Test rvs with tuple sizes and batched parameters.""" + + def test_size_as_tuple(self): + result = MatrixNormal.rvs(pt.constant(M0), pt.constant(U0), pt.constant(V0), size=(3, 5)) + samples = result.eval() + assert samples.shape == (3, 5, 2, 2) + + def test_size_as_int(self): + result = MatrixNormal.rvs(pt.constant(M0), pt.constant(U0), pt.constant(V0), size=10) + samples = result.eval() + assert samples.shape == (10, 2, 2) + + def test_size_none(self): + result = MatrixNormal.rvs(pt.constant(M0), pt.constant(U0), pt.constant(V0), size=None) + sample = result.eval() + assert sample.shape == (2, 2) + + def test_batched_params(self): + mus = np.stack([M0, M0 + 1]) # (2, 2, 2) + result = MatrixNormal.rvs(pt.constant(mus), pt.constant(U0), pt.constant(V0), size=None) + samples = result.eval() + assert samples.shape == (2, 2, 2) + + def test_batched_params_with_size(self): + mus = np.stack([M0, M0 + 1]) # (2, 2, 2) + result = MatrixNormal.rvs(pt.constant(mus), pt.constant(U0), pt.constant(V0), size=(5,)) + samples = result.eval() + assert samples.shape == (5, 2, 2, 2) + + +class TestBatchedMeanModeMedian: + """Test mean/mode/median broadcasting with batched parameters.""" + + def test_mean_broadcasts(self): + rowcovs = np.array([np.eye(2), 2 * np.eye(2)]) # (2, 2, 2) + result = MatrixNormal.mean(pt.constant(M0), pt.constant(rowcovs), pt.constant(V0)) + actual = result.eval() + assert actual.shape == (2, 2, 2) + for i in range(2): + assert_allclose(actual[i], M0) + + def test_unbatched_mean(self): + result = MatrixNormal.mean(pt.constant(M0), pt.constant(U0), pt.constant(V0)) + actual = result.eval() + assert_allclose(actual, M0) From 6d6b52f2839fd15773ec1c4cf2dc9c603ed35db3 Mon Sep 17 00:00:00 2001 From: Chris Fonnesbeck Date: Mon, 23 Feb 2026 14:32:19 -0600 Subject: [PATCH 3/3] Additional test coverage --- tests/test_matrixnormal.py | 64 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 64 insertions(+) diff --git a/tests/test_matrixnormal.py b/tests/test_matrixnormal.py index 3746f63..e01ff81 100644 --- a/tests/test_matrixnormal.py +++ b/tests/test_matrixnormal.py @@ -210,3 +210,67 @@ def test_unbatched_mean(self): result = MatrixNormal.mean(pt.constant(M0), pt.constant(U0), pt.constant(V0)) actual = result.eval() assert_allclose(actual, M0) + + def test_mode_broadcasts(self): + rowcovs = np.array([np.eye(2), 2 * np.eye(2)]) + actual = MatrixNormal.mode(pt.constant(M0), pt.constant(rowcovs), pt.constant(V0)).eval() + assert actual.shape == (2, 2, 2) + + def test_median_broadcasts(self): + rowcovs = np.array([np.eye(2), 2 * np.eye(2)]) + actual = MatrixNormal.median(pt.constant(M0), pt.constant(rowcovs), pt.constant(V0)).eval() + assert actual.shape == (2, 2, 2) + + +class TestBatchedOtherFunctions: + """Test remaining functions with batched inputs.""" + + def test_var_batched_colcov(self): + colcovs = np.array([np.eye(2), 3 * np.eye(2)]) + actual = MatrixNormal.var(pt.constant(M0), pt.constant(U0), pt.constant(colcovs)).eval() + assert actual.shape == (2, 2, 2) + for i in range(2): + expected = np.outer(np.diag(U0), np.diag(colcovs[i])) + assert_allclose(actual[i], expected, rtol=1e-10) + + def test_std_batched(self): + rowcovs = np.array([np.eye(2), 2 * np.eye(2)]) + actual = MatrixNormal.std(pt.constant(M0), pt.constant(rowcovs), pt.constant(V0)).eval() + assert actual.shape == (2, 2, 2) + var_result = MatrixNormal.var(pt.constant(M0), pt.constant(rowcovs), pt.constant(V0)).eval() + assert_allclose(actual, np.sqrt(var_result), rtol=1e-10) + + def test_skewness_batched(self): + rowcovs = np.array([np.eye(2), 2 * np.eye(2)]) + actual = MatrixNormal.skewness( + pt.constant(M0), pt.constant(rowcovs), pt.constant(V0) + ).eval() + assert actual.shape == (2, 2, 2) + assert_allclose(actual, 0) + + def test_kurtosis_batched(self): + rowcovs = np.array([np.eye(2), 2 * np.eye(2)]) + actual = MatrixNormal.kurtosis( + pt.constant(M0), pt.constant(rowcovs), pt.constant(V0) + ).eval() + assert actual.shape == (2, 2, 2) + assert_allclose(actual, 0) + + def test_entropy_batched(self): + rowcovs = np.array([U0, 2 * np.eye(2)]) + actual = MatrixNormal.entropy(pt.constant(M0), pt.constant(rowcovs), pt.constant(V0)).eval() + assert actual.shape == (2,) + for i in range(2): + expected = scipy_matrix_normal(mean=M0, rowcov=rowcovs[i], colcov=V0).entropy() + assert_allclose(actual[i], expected, rtol=1e-5) + + def test_logpdf_batched_params(self): + X = M0 + 0.1 + rowcovs = np.array([U0, 2 * np.eye(2)]) + actual = MatrixNormal.logpdf( + pt.constant(X), pt.constant(M0), pt.constant(rowcovs), pt.constant(V0) + ).eval() + assert actual.shape == (2,) + for i in range(2): + expected = scipy_matrix_normal(mean=M0, rowcov=rowcovs[i], colcov=V0).logpdf(X) + assert_allclose(actual[i], expected, rtol=1e-5)