Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 47 additions & 13 deletions conjugate/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,7 +631,9 @@ def from_inverse_gamma(
def inverse_gamma(self) -> InverseGamma:
return InverseGamma(alpha=self.alpha, beta=self.beta)

def sample_variance(self, size: int, random_state=None) -> NUMERIC:
def sample_variance(
self, size: int, random_state: np.random.RandomState | None = None
) -> NUMERIC:
"""Sample variance from the inverse gamma distribution.

Args:
Expand All @@ -644,11 +646,15 @@ def sample_variance(self, size: int, random_state=None) -> NUMERIC:
"""
return self.inverse_gamma.dist.rvs(size=size, random_state=random_state)

def _sample_beta_1d(self, variance, size: int, random_state=None) -> NUMERIC:
def _sample_beta_1d(
self, variance, size: int, random_state: np.random.RandomState | None = None
) -> NUMERIC:
sigma = (variance / self.nu) ** 0.5
return stats.norm(self.mu, sigma).rvs(size=size, random_state=random_state)

def _sample_beta_nd(self, variance, size: int, random_state=None) -> NUMERIC:
def _sample_beta_nd(
self, variance, size: int, random_state: np.random.RandomState | None = None
) -> NUMERIC:
variance = (self.delta_inverse[None, ...].T * variance).T
return np.stack(
[
Expand All @@ -663,7 +669,7 @@ def sample_mean(
self,
size: int,
return_variance: bool = False,
random_state=None,
random_state: np.random.RandomState | None = None,
) -> NUMERIC | tuple[NUMERIC, NUMERIC]:
"""Sample the mean from the normal distribution.

Expand All @@ -681,7 +687,10 @@ def sample_mean(
)

def sample_beta(
self, size: int, return_variance: bool = False, random_state=None
self,
size: int,
return_variance: bool = False,
random_state: np.random.RandomState | None = None,
) -> NUMERIC | tuple[NUMERIC, NUMERIC]:
"""Sample beta from the normal distribution.

Expand Down Expand Up @@ -809,7 +818,11 @@ class GammaKnownRateProportional:
c: NUMERIC

def approx_log_likelihood(
self, alpha: NUMERIC, beta: NUMERIC, ln=np.log, gammaln=gammaln
self,
alpha: NUMERIC,
beta: NUMERIC,
ln: Callable = np.log,
gammaln: Callable = gammaln,
) -> NUMERIC:
"""Approximate log likelihood.

Expand Down Expand Up @@ -848,7 +861,11 @@ class GammaProportional:
s: NUMERIC

def approx_log_likelihood(
self, alpha: NUMERIC, beta: NUMERIC, ln=np.log, gammaln=gammaln
self,
alpha: NUMERIC,
beta: NUMERIC,
ln: Callable = np.log,
gammaln: Callable = gammaln,
) -> NUMERIC:
"""Approximate log likelihood.

Expand Down Expand Up @@ -886,7 +903,11 @@ class BetaProportional:
k: NUMERIC

def approx_log_likelihood(
self, alpha: NUMERIC, beta: NUMERIC, ln=np.log, gammaln=gammaln
self,
alpha: NUMERIC,
beta: NUMERIC,
ln: Callable = np.log,
gammaln: Callable = gammaln,
) -> NUMERIC:
"""Approximate log likelihood.

Expand Down Expand Up @@ -946,7 +967,13 @@ class VonMisesKnownConcentration:
a: NUMERIC
b: NUMERIC

def log_likelihood(self, mu: NUMERIC, cos=np.cos, ln=np.log, i0=i0) -> NUMERIC:
def log_likelihood(
self,
mu: NUMERIC,
cos: Callable = np.cos,
ln: Callable = np.log,
i0: Callable = i0,
) -> NUMERIC:
"""Approximate log likelihood.

Args:
Expand Down Expand Up @@ -976,7 +1003,9 @@ class VonMisesKnownDirectionProportional:
c: NUMERIC
r: NUMERIC

def approx_log_likelihood(self, kappa: NUMERIC, ln=np.log, i0=i0) -> NUMERIC:
def approx_log_likelihood(
self, kappa: NUMERIC, ln: Callable = np.log, i0: Callable = i0
) -> NUMERIC:
"""Approximate log likelihood.

Args:
Expand Down Expand Up @@ -1058,7 +1087,9 @@ class NormalGamma:
def gamma(self) -> Gamma:
return Gamma(alpha=self.alpha, beta=self.beta)

def sample_variance(self, size: int, random_state=None) -> NUMERIC:
def sample_variance(
self, size: int, random_state: np.random.RandomState | None = None
) -> NUMERIC:
"""Sample precision from gamma distribution and invert.

Args:
Expand All @@ -1077,7 +1108,7 @@ def sample_mean(
self,
size: int,
return_variance: bool = False,
random_state=None,
random_state: np.random.RandomState | None = None,
) -> NUMERIC | tuple[NUMERIC, NUMERIC]:
"""Sample mean from the normal distribution.

Expand All @@ -1095,7 +1126,10 @@ def sample_mean(
)

def sample_beta(
self, size: int, return_variance: bool = False, random_state=None
self,
size: int,
return_variance: bool = False,
random_state: np.random.RandomState | None = None,
) -> NUMERIC | tuple[NUMERIC, NUMERIC]:
"""Sample beta from the normal distribution.

Expand Down
Loading