Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 13 additions & 10 deletions src/hssm/param.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ def __init__(
link: str | bmb.Link | None = None,
bounds: tuple[float, float] | None = None,
):
if name is None:
raise ValueError("A name must be specified.")
self.name = name
self.prior = prior
self.formula = formula
Expand Down Expand Up @@ -211,7 +213,6 @@ def override_default_priors_ddm(self, data: pd.DataFrame, eval_env: dict[str, An
The environment used to evaluate the formula.
"""
self._ensure_not_converted(context="prior")
assert self.name is not None

if not self.is_regression:
return
Expand Down Expand Up @@ -319,10 +320,10 @@ def convert(self):
if any(not np.isscalar(bound) for bound in self.bounds):
raise ValueError(f"The bounds of {self.name} should both be scalar.")
lower, upper = self.bounds
assert lower < upper, (
f"The lower bound of {self.name} should be less than "
+ "its upper bound."
)
if not lower < upper:
raise ValueError(
f"{self.name}: lower bound must be less than upper bound."
)

if isinstance(self.prior, int):
self.prior = float(self.prior)
Expand Down Expand Up @@ -452,7 +453,6 @@ def parse_bambi(
link = {self.name: self.link}
return formula, prior, link

assert self.name is not None
if self.prior is not None:
prior = {self.name: self.prior}
if self.link is not None:
Expand All @@ -470,7 +470,6 @@ def __repr__(self) -> str:
regression or not.
"""
output = []
assert self.name is not None
output.append(self.name + ":")

# Simplest case: float
Expand All @@ -482,12 +481,15 @@ def __repr__(self) -> str:
# Regression case:
# Output formula, priors, and link functions
if self.is_regression:
assert self.formula is not None
if self.formula is None:
raise ValueError("Formula must be specified for regression.")

output.append(f" Formula: {self.formula}")
output.append(" Priors:")

if self.prior is not None:
assert isinstance(self.prior, dict)
if not isinstance(self.prior, dict):
raise TypeError("The prior for a regression must be a dict.")

for param, prior in self.prior.items():
output.append(f" {param} ~ {prior}")
Expand All @@ -505,7 +507,8 @@ def __repr__(self) -> str:
# None regression case:
# Output prior and bounds
else:
assert isinstance(self.prior, bmb.Prior)
if not isinstance(self.prior, bmb.Prior):
raise TypeError("The prior must be an instance of bmb.Prior.")
output.append(f" Prior: {self.prior}")

output.append(f" Explicit bounds: {self.bounds}")
Expand Down