Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions .github/workflows/run_ruff.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
name: Ruff

on:
push:
branches: [ main ]
pull_request:
branches: [ main ]

jobs:
ruff:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: chartboost/ruff-action@v1
- uses: chartboost/ruff-action@v1
with:
args: 'format --check'

3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,6 @@ docs/*.inv
build/*
dist/*
*.whl

# Local pre-commit hooks
.pre-commit-config.yaml
9 changes: 6 additions & 3 deletions RAT/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import os

from RAT import models
from RAT.classlist import ClassList
from RAT.project import Project
from RAT.controls import set_controls
from RAT.project import Project
from RAT.run import run
import RAT.models

__all__ = ["ClassList", "Project", "run", "set_controls", "models"]

dir_path = os.path.dirname(os.path.realpath(__file__))
os.environ["RAT_PATH"] = os.path.join(dir_path, '')
os.environ["RAT_PATH"] = os.path.join(dir_path, "")
81 changes: 56 additions & 25 deletions RAT/classlist.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
"""The classlist module. Contains the ClassList class, which defines a list containing instances of a particular class.
"""The classlist module. Contains the ClassList class, which defines a list containing instances of a particular
class.
"""

import collections
from collections.abc import Iterable, Sequence
import contextlib
import prettytable
from typing import Any, Union
import warnings
from collections.abc import Iterable, Sequence
from typing import Any, Union

import prettytable


class ClassList(collections.UserList):
Expand All @@ -31,7 +33,9 @@ class ClassList(collections.UserList):
An instance, or list of instance(s), of the class to be used in this ClassList.
name_field : str, optional
The field used to define unique objects in the ClassList (default is "name").

"""

def __init__(self, init_list: Union[Sequence[object], object] = None, name_field: str = "name") -> None:
self.name_field = name_field

Expand All @@ -56,7 +60,7 @@ def __repr__(self):
else:
if any(model.__dict__ for model in self.data):
table = prettytable.PrettyTable()
table.field_names = ['index'] + [key.replace('_', ' ') for key in self.data[0].__dict__.keys()]
table.field_names = ["index"] + [key.replace("_", " ") for key in self.data[0].__dict__]
table.add_rows([[index] + list(model.__dict__.values()) for index, model in enumerate(self.data)])
output = table.get_string()
else:
Expand All @@ -81,15 +85,15 @@ def _delitem(self, index: int) -> None:
"""Auxiliary routine of "__delitem__" used to enable wrapping."""
del self.data[index]

def __iadd__(self, other: Sequence[object]) -> 'ClassList':
def __iadd__(self, other: Sequence[object]) -> "ClassList":
"""Define in-place addition using the "+=" operator."""
return self._iadd(other)

def _iadd(self, other: Sequence[object]) -> 'ClassList':
def _iadd(self, other: Sequence[object]) -> "ClassList":
"""Auxiliary routine of "__iadd__" used to enable wrapping."""
if other and not (isinstance(other, Sequence) and not isinstance(other, str)):
other = [other]
if not hasattr(self, '_class_handle'):
if not hasattr(self, "_class_handle"):
self._class_handle = self._determine_class_handle(self + other)
self._check_classes(self + other)
self._check_unique_name_fields(self + other)
Expand Down Expand Up @@ -129,20 +133,27 @@ def append(self, obj: object = None, **kwargs) -> None:
SyntaxWarning
Raised if the input arguments contain BOTH an object and keyword arguments. In this situation the object is
appended to the ClassList and the keyword arguments are discarded.

"""
if obj and kwargs:
warnings.warn('ClassList.append() called with both an object and keyword arguments. '
'The keyword arguments will be ignored.', SyntaxWarning)
warnings.warn(
"ClassList.append() called with both an object and keyword arguments. "
"The keyword arguments will be ignored.",
SyntaxWarning,
stacklevel=2,
)
if obj:
if not hasattr(self, '_class_handle'):
if not hasattr(self, "_class_handle"):
self._class_handle = type(obj)
self._check_classes(self + [obj])
self._check_unique_name_fields(self + [obj])
self.data.append(obj)
else:
if not hasattr(self, '_class_handle'):
raise TypeError('ClassList.append() called with keyword arguments for a ClassList without a class '
'defined. Call ClassList.append() with an object to define the class.')
if not hasattr(self, "_class_handle"):
raise TypeError(
"ClassList.append() called with keyword arguments for a ClassList without a class "
"defined. Call ClassList.append() with an object to define the class.",
)
self._validate_name_field(kwargs)
self.data.append(self._class_handle(**kwargs))

Expand All @@ -169,20 +180,27 @@ def insert(self, index: int, obj: object = None, **kwargs) -> None:
SyntaxWarning
Raised if the input arguments contain both an object and keyword arguments. In this situation the object is
inserted into the ClassList and the keyword arguments are discarded.

"""
if obj and kwargs:
warnings.warn('ClassList.insert() called with both an object and keyword arguments. '
'The keyword arguments will be ignored.', SyntaxWarning)
warnings.warn(
"ClassList.insert() called with both an object and keyword arguments. "
"The keyword arguments will be ignored.",
SyntaxWarning,
stacklevel=2,
)
if obj:
if not hasattr(self, '_class_handle'):
if not hasattr(self, "_class_handle"):
self._class_handle = type(obj)
self._check_classes(self + [obj])
self._check_unique_name_fields(self + [obj])
self.data.insert(index, obj)
else:
if not hasattr(self, '_class_handle'):
raise TypeError('ClassList.insert() called with keyword arguments for a ClassList without a class '
'defined. Call ClassList.insert() with an object to define the class.')
if not hasattr(self, "_class_handle"):
raise TypeError(
"ClassList.insert() called with keyword arguments for a ClassList without a class "
"defined. Call ClassList.insert() with an object to define the class.",
)
self._validate_name_field(kwargs)
self.data.insert(index, self._class_handle(**kwargs))

Expand All @@ -209,7 +227,7 @@ def extend(self, other: Sequence[object]) -> None:
"""Extend the ClassList by adding another sequence."""
if other and not (isinstance(other, Sequence) and not isinstance(other, str)):
other = [other]
if not hasattr(self, '_class_handle'):
if not hasattr(self, "_class_handle"):
self._class_handle = self._determine_class_handle(self + other)
self._check_classes(self + other)
self._check_unique_name_fields(self + other)
Expand All @@ -229,6 +247,7 @@ def get_names(self) -> list[str]:
-------
names : list [str]
The value of the name_field attribute of each object in the ClassList.

"""
return [getattr(model, self.name_field) for model in self.data if hasattr(model, self.name_field)]

Expand All @@ -244,9 +263,14 @@ def get_all_matches(self, value: Any) -> list[tuple]:
-------
: list [tuple]
A list of (index, field) tuples matching the given value.

"""
return [(index, field) for index, element in enumerate(self.data) for field in vars(element)
if getattr(element, field) == value]
return [
(index, field)
for index, element in enumerate(self.data)
for field in vars(element)
if getattr(element, field) == value
]

def _validate_name_field(self, input_args: dict[str, Any]) -> None:
"""Raise a ValueError if the name_field attribute is passed as an object parameter, and its value is already
Expand All @@ -261,12 +285,15 @@ def _validate_name_field(self, input_args: dict[str, Any]) -> None:
------
ValueError
Raised if the input arguments contain a name_field value already defined in the ClassList.

"""
names = self.get_names()
with contextlib.suppress(KeyError):
if input_args[self.name_field] in names:
raise ValueError(f"Input arguments contain the {self.name_field} '{input_args[self.name_field]}', "
f"which is already specified in the ClassList")
raise ValueError(
f"Input arguments contain the {self.name_field} '{input_args[self.name_field]}', "
f"which is already specified in the ClassList",
)

def _check_unique_name_fields(self, input_list: Iterable[object]) -> None:
"""Raise a ValueError if any value of the name_field attribute is used more than once in a list of class
Expand All @@ -281,6 +308,7 @@ def _check_unique_name_fields(self, input_list: Iterable[object]) -> None:
------
ValueError
Raised if the input list defines more than one object with the same value of name_field.

"""
names = [getattr(model, self.name_field) for model in input_list if hasattr(model, self.name_field)]
if len(set(names)) != len(names):
Expand All @@ -298,6 +326,7 @@ def _check_classes(self, input_list: Iterable[object]) -> None:
------
ValueError
Raised if the input list defines objects of different types.

"""
if not (all(isinstance(element, self._class_handle) for element in input_list)):
raise ValueError(f"Input list contains elements of type other than '{self._class_handle.__name__}'")
Expand All @@ -315,6 +344,7 @@ def _get_item_from_name_field(self, value: Union[object, str]) -> Union[object,
instance : object or str
Either the object with the value of the name_field attribute given by value, or the input value if an
object with that value of the name_field attribute cannot be found.

"""
return next((model for model in self.data if getattr(model, self.name_field) == value), value)

Expand All @@ -333,6 +363,7 @@ def _determine_class_handle(input_list: Sequence[object]):
class_handle : type
The type object of the element fulfilling the condition of satisfying "issubclass" for all of the other
elements.

"""
for this_element in input_list:
if all([issubclass(type(instance), type(this_element)) for instance in input_list]):
Expand Down
40 changes: 25 additions & 15 deletions RAT/controls.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
from dataclasses import dataclass, field
import prettytable
from pydantic import BaseModel, Field, field_validator, ValidationError
from typing import Literal, Union

from RAT.utils.enums import Parallel, Procedures, Display, BoundHandling, Strategies
import prettytable
from pydantic import BaseModel, Field, ValidationError, field_validator

from RAT.utils.custom_errors import custom_pydantic_validation_error
from RAT.utils.enums import BoundHandling, Display, Parallel, Procedures, Strategies


@dataclass(frozen=True)
class Controls:
"""The full set of controls parameters required for the compiled RAT code."""

# All Procedures
procedure: Procedures = Procedures.Calculate
parallel: Parallel = Parallel.Single
Expand Down Expand Up @@ -44,8 +46,9 @@ class Controls:
adaptPCR: bool = False


class Calculate(BaseModel, validate_assignment=True, extra='forbid'):
class Calculate(BaseModel, validate_assignment=True, extra="forbid"):
"""Defines the class for the calculate procedure, which includes the properties used in all five procedures."""

procedure: Literal[Procedures.Calculate] = Procedures.Calculate
parallel: Parallel = Parallel.Single
calcSldDuringFit: bool = False
Expand All @@ -56,20 +59,21 @@ class Calculate(BaseModel, validate_assignment=True, extra='forbid'):
@classmethod
def check_resample_params(cls, resampleParams):
if not 0 < resampleParams[0] < 1:
raise ValueError('resampleParams[0] must be between 0 and 1')
raise ValueError("resampleParams[0] must be between 0 and 1")
if resampleParams[1] < 0:
raise ValueError('resampleParams[1] must be greater than or equal to 0')
raise ValueError("resampleParams[1] must be greater than or equal to 0")
return resampleParams

def __repr__(self) -> str:
table = prettytable.PrettyTable()
table.field_names = ['Property', 'Value']
table.field_names = ["Property", "Value"]
table.add_rows([[k, v] for k, v in self.__dict__.items()])
return table.get_string()


class Simplex(Calculate):
"""Defines the additional fields for the simplex procedure."""

procedure: Literal[Procedures.Simplex] = Procedures.Simplex
xTolerance: float = Field(1.0e-6, gt=0.0)
funcTolerance: float = Field(1.0e-6, gt=0.0)
Expand All @@ -81,6 +85,7 @@ class Simplex(Calculate):

class DE(Calculate):
"""Defines the additional fields for the Differential Evolution procedure."""

procedure: Literal[Procedures.DE] = Procedures.DE
populationSize: int = Field(20, ge=1)
fWeight: float = 0.5
Expand All @@ -92,6 +97,7 @@ class DE(Calculate):

class NS(Calculate):
"""Defines the additional fields for the Nested Sampler procedure."""

procedure: Literal[Procedures.NS] = Procedures.NS
nLive: int = Field(150, ge=1)
nMCMC: float = Field(0.0, ge=0.0)
Expand All @@ -101,6 +107,7 @@ class NS(Calculate):

class Dream(Calculate):
"""Defines the additional fields for the Dream procedure."""

procedure: Literal[Procedures.Dream] = Procedures.Dream
nSamples: int = Field(50000, ge=0)
nChains: int = Field(10, gt=0)
Expand All @@ -110,28 +117,31 @@ class Dream(Calculate):
adaptPCR: bool = False


def set_controls(procedure: Procedures = Procedures.Calculate, **properties)\
-> Union[Calculate, Simplex, DE, NS, Dream]:
def set_controls(
procedure: Procedures = Procedures.Calculate,
**properties,
) -> Union[Calculate, Simplex, DE, NS, Dream]:
"""Returns the appropriate controls model given the specified procedure."""
controls = {
Procedures.Calculate: Calculate,
Procedures.Simplex: Simplex,
Procedures.DE: DE,
Procedures.NS: NS,
Procedures.Dream: Dream
Procedures.Dream: Dream,
}

try:
model = controls[procedure](**properties)
except KeyError:
members = list(Procedures.__members__.values())
allowed_values = f'{", ".join([repr(member.value) for member in members[:-1]])} or {members[-1].value!r}'
raise ValueError(f'The controls procedure must be one of: {allowed_values}') from None
raise ValueError(f"The controls procedure must be one of: {allowed_values}") from None
except ValidationError as exc:
custom_error_msgs = {'extra_forbidden': f'Extra inputs are not permitted. The fields for the {procedure}'
f' controls procedure are:\n '
f'{", ".join(controls[procedure].model_fields.keys())}\n'
}
custom_error_msgs = {
"extra_forbidden": f'Extra inputs are not permitted. The fields for the {procedure}'
f' controls procedure are:\n '
f'{", ".join(controls[procedure].model_fields.keys())}\n',
}
custom_error_list = custom_pydantic_validation_error(exc.errors(), custom_error_msgs)
raise ValidationError.from_exception_data(exc.title, custom_error_list) from None

Expand Down
Loading