class ContinuousNDSupport(IntervalND, Support):  # type: ignore[misc]
    """
    Support for continuous distributions represented as an array of intervals.

    This class inherits from IntervalND and implements the Support protocol
    for continuous distributions defined on a list of intervals [left, right].
    """


class SupportByPredicate:
    """
    Support described by an arbitrary membership predicate.

    The predicate receives the candidate point and its result decides whether
    ``point in support`` holds.

    Parameters
    ----------
    predicate : Callable[[NumericArray | Number], bool]
        Function called with the tested point; a truthy result means the
        point belongs to the support.
    """

    def __init__(self, predicate: Callable[[NumericArray | Number], bool]):
        self._predicate = predicate

    def __contains__(self, item: NumericArray | Number) -> bool:
        """Return ``True`` when the predicate accepts ``item``."""
        # Coerce explicitly so that a direct ``support.__contains__(x)`` call
        # honours the ``bool`` annotation even if the predicate returns a
        # NumPy boolean (the ``in`` operator already coerces on its own).
        return bool(self._predicate(item))


class SupportByIntervals(SupportByPredicate):
    """
    Predicate-style support backed by a ``ContinuousNDSupport``.

    Membership is delegated to ``x in support`` of the wrapped object.

    Parameters
    ----------
    support : ContinuousNDSupport
        Interval-based support whose ``in`` operator defines membership.
    """

    def __init__(self, support: ContinuousNDSupport):
        super().__init__(lambda x: x in support)
@dataclass
class ExponentialFamilyParametrization(Parametrization):
    """
    Standard parametrization of Exponential Family.

    Holds the single parameter ``theta`` that
    ``ContinuousExponentialClassFamily`` pairs with the sufficient statistics
    when it evaluates the log-density.
    """

    theta: NumberParameter

    def transform_to_base_parametrization(self) -> ExponentialFamilyParametrization:
        """Return ``self``; no conversion is performed for this parametrization."""
        return self


@dataclass
class ExponentialConjugateHyperparameters:
    """
    Hyperparameters consumed and produced by
    ``ContinuousExponentialClassFamily.posterior_hyperparameters``.
    """

    # Accumulated value of the sufficient statistics.
    effective_suff_stat_value: NumberParameter
    # Accumulated number of observations.
    effective_sample_size: int


class ContinuousExponentialClassFamily(ParametricFamily):
    """
    Representation of an exponential class with density

        f(x | theta) = h(x) * exp(<theta, T(x)> + A(theta)),

    where ``h`` is ``normalization_constant``, ``T`` is
    ``sufficient_statistics`` and ``A`` is ``log_partition``. Note that
    ``A(theta)`` is *added* inside the exponent, so ``log_partition`` must be
    supplied with the sign that makes the density integrate to one.

    Usage of this class:
    - use ``transform`` to express the family in a new variable ``y`` with
      ``x = transform_function(y)``;
    - use ``conjugate_prior_family`` to obtain the family over ``theta``
      built from ``(theta, A(theta))``;
    - use ``posterior_hyperparameters`` to update conjugate hyperparameters
      with observed data.
    """

    def __init__(
        self,
        *,
        log_partition: Callable[[NumberParameter], NumberParameter],
        sufficient_statistics: Callable[[NumberParameter], NumberParameter],
        normalization_constant: Callable[[NumberParameter], NumberParameter],
        support: SupportByPredicate,
        parameter_space: SupportByPredicate,
        sufficient_statistics_values: SupportByPredicate,
        name: str = "ExponentialFamily",
        distr_type: DistributionType | Callable[[Parametrization], DistributionType],
        distr_parametrizations: list[ParametrizationName],
        support_by_parametrization: SupportArg = None,
    ):
        """
        Store the building blocks of the family and register its characteristics.

        Parameters
        ----------
        log_partition
            ``A(theta)``; added to ``<theta, T(x)>`` in the log-density.
        sufficient_statistics
            ``T(x)``.
        normalization_constant
            ``h(x)``; its logarithm is added to the log-density.
        support
            Membership test for ``x``; outside of it the log-density is ``-inf``.
        parameter_space
            Membership test for ``theta``.
        sufficient_statistics_values
            Membership test for values of the sufficient statistics.
        name, distr_type, distr_parametrizations, support_by_parametrization
            Forwarded unchanged to ``ParametricFamily.__init__``.
        """
        self._sufficient = sufficient_statistics
        self._log_partition = log_partition
        self._normalization = normalization_constant

        self._support = support
        self._parameter_space = parameter_space
        self._sufficient_statistics_values = sufficient_statistics_values

        distr_characteristics: dict[
            GenericCharacteristicName,
            dict[ParametrizationName, ParametrizedFunction] | ParametrizedFunction,
        ] = {
            CharacteristicName.PDF: self.density,
            CharacteristicName.MEAN: self._mean,
            CharacteristicName.VAR: self._var,
        }

        super().__init__(
            name=name,
            distr_type=distr_type,
            distr_parametrizations=distr_parametrizations,
            distr_characteristics=distr_characteristics,
            support_by_parametrization=support_by_parametrization,
        )
        parametrization(family=self, name="theta")(ExponentialFamilyParametrization)

    @property
    def log_density(self) -> ParametrizedFunction:
        """
        Function ``(parametrization, x) -> log f(x | theta)``.

        The returned function yields ``-inf`` for ``x`` outside the support and
        ``log h(x) + <theta, T(x)> + A(theta)`` otherwise.
        """

        def log_density_func(parametrization: Parametrization, x: NumberParameter) -> Number:
            parametrization = cast(ExponentialFamilyParametrization, parametrization)
            parametrization = parametrization.transform_to_base_parametrization()
            if x not in self._support:
                return -np.inf

            theta = parametrization.theta
            sufficient = self._sufficient(x)
            dot = np.dot(theta, sufficient)
            # A sized result is reduced to its first entry.
            if hasattr(dot, "__len__"):
                dot = dot[0]

            result = np.log(self._normalization(x)) + dot + self._log_partition(theta)
            return cast(np.floating, result.item())

        return log_density_func

    @property
    def density(self) -> ParametrizedFunction:
        """Function ``(parametrization, x) -> f(x | theta)``: exponent of ``log_density``."""
        return lambda parametrization, x: np.exp(self.log_density(parametrization, x))

    @property
    def conjugate_prior_family(self) -> ContinuousExponentialClassFamily:
        """
        Build the conjugate prior family over ``theta``.

        The new family uses ``(theta, A(theta))`` as sufficient statistics, a
        constant ``h = 1`` and this family's parameter space as its support.
        Its log-partition is obtained numerically as minus the logarithm of
        the integral of ``exp(<(theta, A(theta)), hyperparameters>)``.

        NOTE(review): the integral is taken over a single real axis, so only
        a one-dimensional ``theta`` is handled here.
        """

        def conjugate_sufficient(
            theta: NumberParameter,
        ) -> NumberParameter:
            if not hasattr(theta, "__len__"):
                theta = np.array([theta])

            # Outside the parameter space every statistic is -inf.
            if theta not in self._parameter_space:
                return np.full(len(theta) + 1, float("-inf"))
            return np.append(theta, self._log_partition(theta))

        def conjugate_log_partition(
            parametrization: NumberParameter,
        ) -> NumberParameter:
            def pdf(theta: NumberParameter) -> NumberParameter:
                if not hasattr(theta, "__len__"):
                    theta = np.array([theta])
                return cast(
                    np.floating,
                    np.exp(
                        np.dot(
                            conjugate_sufficient(theta),
                            parametrization,
                        )
                    ).item(),
                )

            all_value = nquad(
                lambda x: pdf(x) if x in self._parameter_space else 0,  # type: ignore[arg-type]
                [(float("-inf"), float("+inf"))],
            )[0]
            return cast(np.float64, -np.log(all_value))

        def conjugate_sufficient_accepts(
            theta: NumericArray,
        ) -> bool:
            # Last entry is nu; everything before it is xi.
            xi = theta[:-1]
            nu = theta[-1]

            return xi in self._sufficient_statistics_values and nu in ContinuousSupport(0, np.inf)

        return ContinuousExponentialClassFamily(
            log_partition=conjugate_log_partition,
            sufficient_statistics=conjugate_sufficient,
            normalization_constant=lambda _: 1,
            support=self._parameter_space,
            sufficient_statistics_values=self._parameter_space,  # TODO: write convex hull for this
            parameter_space=SupportByPredicate(conjugate_sufficient_accepts),  # type: ignore[arg-type]
            name=self.name,
            distr_type=self._distr_type,
            distr_parametrizations=self.parametrization_names,
            support_by_parametrization=self.support_resolver,
        )

    def transform(
        self,
        transform_function: Callable[[Any], Any],
    ) -> ContinuousExponentialClassFamily:
        """
        Re-express the family in a new variable ``y`` with ``x = transform_function(y)``.

        By the change-of-variables formula the new density is
        ``h(g(y)) * |det J_g(y)| * exp(<theta, T(g(y))> + A(theta))`` with
        ``g = transform_function``; the Jacobian is computed numerically with
        ``scipy.differentiate.jacobian``.

        Parameters
        ----------
        transform_function
            Map from the new variable to the original one. It is also passed
            to ``scipy.differentiate.jacobian`` and therefore has to accept
            array input.

        Returns
        -------
        ContinuousExponentialClassFamily
            New family named ``"Transformed<name>"`` sharing the log-partition,
            the parameter space and the sufficient-statistics range with this one.
        """

        def calculate_jacobian(x: Any) -> Any:
            # Scalars become a length-1 vector; vectors keep their shape, so
            # the Jacobian is always a square matrix.
            point = np.atleast_1d(x)
            return np.abs(det(jacobian(transform_function, point).df))

        def new_support(x: Any) -> bool:
            return transform_function(x) in self._support

        def new_sufficient(x: Any) -> Any:
            return self._sufficient(transform_function(x))

        def new_normalization(x: Any) -> Any:
            # h has to be evaluated at the transformed point, exactly like the
            # support and the sufficient statistics above.
            return self._normalization(transform_function(x)) * calculate_jacobian(x)

        return ContinuousExponentialClassFamily(
            log_partition=self._log_partition,
            sufficient_statistics=new_sufficient,
            normalization_constant=new_normalization,
            support=SupportByPredicate(new_support),
            parameter_space=self._parameter_space,
            sufficient_statistics_values=self._sufficient_statistics_values,
            name=f"Transformed{self._name}",
            distr_type=self._distr_type,
            distr_parametrizations=self.parametrization_names,
            support_by_parametrization=self.support_resolver,
        )

    def _raw_moment(self, parametrization: Parametrization, x: Any, power: int) -> Any:
        """
        Integrate ``coordinate**power * density`` over the whole real space.

        ``x`` is used only to determine the dimension (``len(x)`` when it is
        sized, otherwise 1). In one dimension a scalar is returned; in higher
        dimensions an array with one integral per coordinate is returned.
        """
        parametrization = cast(ExponentialFamilyParametrization, parametrization)
        dimension_size = len(x) if hasattr(x, "__len__") else 1
        limits = [(float("-inf"), float("inf"))] * dimension_size

        def integrand(point: Any, coordinate: Any) -> Any:
            if point not in self._support:
                return 0
            return coordinate**power * self.density(parametrization, point)

        if dimension_size <= 1:
            # nquad hands over the single integration variable as a plain float.
            return nquad(lambda value: integrand(value, value), limits)[0]

        def component_moment(index: int) -> Any:
            # nquad calls the integrand with one positional argument per
            # integration variable, hence the ``*coordinates`` signature.
            def func(*coordinates: float) -> Any:
                point = np.array(coordinates)
                return integrand(point, point[index])

            return nquad(func, limits)[0]

        return np.array([component_moment(index) for index in range(dimension_size)])

    @property
    def _mean(self) -> ParametrizedFunction:
        """Function ``(parametrization, x) -> E[X]``; ``x`` only fixes the dimension."""

        def mean_func(parametrization: Parametrization, x: Any) -> Any:
            return self._raw_moment(parametrization, x, 1)

        return mean_func

    @property
    def _second_moment(self) -> ParametrizedFunction:
        """Function ``(parametrization, x) -> E[X**2]``; ``x`` only fixes the dimension."""

        def func(parametrization: Parametrization, x: Any) -> Any:
            return self._raw_moment(parametrization, x, 2)

        return func

    @property
    def _var(self) -> ParametrizedFunction:
        """Function ``(parametrization, x) -> E[X**2] - E[X]**2``."""

        def func(parametrization: Parametrization, x: Any) -> Any:
            parametrization = cast(ExponentialFamilyParametrization, parametrization)
            return self._second_moment(parametrization, x) - self._mean(parametrization, x) ** 2

        return func

    def posterior_hyperparameters(
        self, prior_hyper: ExponentialConjugateHyperparameters, sample: list[Any]
    ) -> ExponentialConjugateHyperparameters:
        """
        Update conjugate hyperparameters with observed data.

        Parameters
        ----------
        prior_hyper
            Prior hyperparameters; left untouched by this call.
        sample
            Either an iterable of observations (strings excluded) or a single
            observation.

        Returns
        -------
        ExponentialConjugateHyperparameters
            Prior sufficient-statistic value plus the summed sufficient
            statistics of the sample, and the prior sample size plus the
            number of observations.
        """
        if hasattr(sample, "__iter__") and not isinstance(sample, str):
            observed = np.sum(
                [self._sufficient(x) for x in sample],  # type: ignore[arg-type]
                axis=0,
            )
            observed_count = len(sample)
        else:
            observed = self._sufficient(sample)  # type: ignore[arg-type]
            observed_count = 1

        # Plain ``+`` (not ``+=``) so that an array stored in ``prior_hyper``
        # is never modified in place.
        return ExponentialConjugateHyperparameters(
            effective_suff_stat_value=prior_hyper.effective_suff_stat_value + observed,
            effective_sample_size=prior_hyper.effective_sample_size + observed_count,
        )
@dataclass(frozen=True, slots=True)
class IntervalND:
    """
    Axis-aligned box described by one ``Interval1D`` per coordinate.
    """

    # One interval per coordinate, in coordinate order.
    intervals: list[Interval1D]

    def contains(self, x: Number | NumericArray) -> bool | BoolArray:
        """
        Check whether the point ``x`` lies inside every interval.

        A scalar (or 0-d array) is treated as a one-coordinate point.

        Raises
        ------
        ValueError
            If the number of coordinates differs from the number of intervals
            (``zip`` is used with ``strict=True``).
        """
        if isinstance(x, np.ndarray):
            # 0-d arrays expose ``__iter__`` but cannot be iterated.
            point = np.atleast_1d(x)
        elif hasattr(x, "__iter__"):
            point = x
        else:
            point = np.array([x])

        return all(
            coordinate in interval
            for interval, coordinate in zip(self.intervals, point, strict=True)
        )

    def __contains__(self, x: object) -> bool:
        """Check if a single point is inside the N-dimensional interval."""
        return bool(self.contains(cast(Number, x)))


def gamma_pdf(alpha: float, beta: float, x: float) -> float:
    """Reference gamma density with shape ``alpha`` and rate ``beta`` taken from scipy."""
    return scipy.stats.gamma(a=alpha, scale=1 / beta).pdf(x).item()  # type: ignore[attr-defined]


@pytest.fixture(scope="function")
def conjugate_for_exponential() -> ContinuousExponentialClassFamily:
    """
    Conjugate prior of the exponential family, mirrored with ``x -> -x``.

    The resulting family is registered and then fetched back from the
    register under the name ``"TransformedExponentialFamily"``.
    """

    def transform_function(x: list[float] | float) -> list[float] | float:
        if isinstance(x, list):
            return [-x[0]]
        return -x  # type: ignore[operator]

    support_neg = SupportByIntervals(ContinuousNDSupport(intervals=[Interval1D(-np.inf, 0)]))
    support_pos = SupportByIntervals(ContinuousNDSupport(intervals=[Interval1D(0, np.inf)]))
    fam = ContinuousExponentialClassFamily(
        log_partition=lambda parametrization: np.log(-parametrization),
        sufficient_statistics=lambda x: x,
        normalization_constant=lambda _: 1,
        parameter_space=support_neg,
        sufficient_statistics_values=support_pos,
        support=support_pos,
        distr_type=UnivariateContinuous,
        distr_parametrizations=["theta"],
    )

    conjugate_fam = fam.conjugate_prior_family.transform(transform_function)
    ParametricFamilyRegister().register(conjugate_fam)
    return cast(
        ContinuousExponentialClassFamily,
        ParametricFamilyRegister().get("TransformedExponentialFamily"),
    )


def _gamma_case(theta1, theta2, family):
    """Build the distribution for ``(theta1, theta2)`` and the matching gamma ``(alpha, beta)``."""
    distribution = family(theta=np.array([theta1, theta2]), parametrization_name="theta")
    return distribution, theta2 + 1, theta1


@pytest.mark.parametrize("theta1", range(2, 5))
@pytest.mark.parametrize("theta2", range(2, 5))
def test_exponential_pdf(theta1, theta2, conjugate_for_exponential):
    """The pdf agrees with the gamma density on the grid 0.0, 0.1, ..., 9.9."""
    exponential, alpha, beta = _gamma_case(theta1, theta2, conjugate_for_exponential)
    pdf = exponential.computation_strategy.query_method("pdf", distr=exponential)

    x = [i / 10 for i in range(100)]

    assert_allclose([pdf(xx) for xx in x], [gamma_pdf(alpha, beta, xx) for xx in x], rtol=1e-6)


@pytest.mark.parametrize("theta1", range(2, 5))
@pytest.mark.parametrize("theta2", range(2, 5))
def test_exponential_mean(theta1, theta2, conjugate_for_exponential):
    """The mean equals ``alpha / beta``."""
    exponential, alpha, beta = _gamma_case(theta1, theta2, conjugate_for_exponential)
    mean = exponential.computation_strategy.query_method("mean", distr=exponential)
    assert np.isclose(mean(12), alpha / beta, rtol=1e-6)


@pytest.mark.parametrize("theta1", range(2, 5))
@pytest.mark.parametrize("theta2", range(2, 5))
def test_exponential_var(theta1, theta2, conjugate_for_exponential):
    """The variance equals ``alpha / beta**2``."""
    exponential, alpha, beta = _gamma_case(theta1, theta2, conjugate_for_exponential)
    var = exponential.computation_strategy.query_method("var", distr=exponential)
    assert np.isclose(var(12), alpha / beta**2, rtol=1e-6)