Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions RATpy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from RATpy import events, models
from RATpy.classlist import ClassList
from RATpy.controls import set_controls
from RATpy.controls import Controls
from RATpy.project import Project
from RATpy.run import run
from RATpy.utils import plotting

__all__ = ["ClassList", "Project", "run", "set_controls", "models", "events", "plotting"]
__all__ = ["models", "events", "ClassList", "Controls", "Project", "run", "plotting"]
20 changes: 16 additions & 4 deletions RATpy/classlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from collections.abc import Iterable, Sequence
from typing import Any, Union

import numpy as np
import prettytable


Expand Down Expand Up @@ -52,19 +53,30 @@ def __init__(self, init_list: Union[Sequence[object], object] = None, name_field

super().__init__(init_list)

def __repr__(self):
def __str__(self):
try:
[model.__dict__ for model in self.data]
except AttributeError:
output = repr(self.data)
output = str(self.data)
else:
if any(model.__dict__ for model in self.data):
table = prettytable.PrettyTable()
table.field_names = ["index"] + [key.replace("_", " ") for key in self.data[0].__dict__]
table.add_rows([[index] + list(model.__dict__.values()) for index, model in enumerate(self.data)])
table.add_rows(
[
[index]
+ list(
f"{'Data array: ['+' x '.join(str(i) for i in v.shape) if v.size > 0 else '['}]"
if isinstance(v, np.ndarray)
else str(v)
for v in model.__dict__.values()
)
for index, model in enumerate(self.data)
]
)
output = table.get_string()
else:
output = repr(self.data)
output = str(self.data)
return output

def __setitem__(self, index: int, item: object) -> None:
Expand Down
217 changes: 106 additions & 111 deletions RATpy/controls.py
Original file line number Diff line number Diff line change
@@ -1,148 +1,143 @@
from dataclasses import dataclass, field
from typing import Literal, Union
import warnings

import prettytable
from pydantic import BaseModel, Field, ValidationError, field_validator
from pydantic import (
BaseModel,
Field,
ValidationError,
ValidatorFunctionWrapHandler,
field_validator,
model_serializer,
model_validator,
)

from RATpy.utils.custom_errors import custom_pydantic_validation_error
from RATpy.utils.enums import BoundHandling, Display, Parallel, Procedures, Strategies


@dataclass(frozen=True)
class Controls:
"""The full set of controls parameters required for the compiled RAT code."""
common_fields = ["procedure", "parallel", "calcSldDuringFit", "resampleParams", "display"]
update_fields = ["updateFreq", "updatePlotFreq"]
fields = {
"calculate": common_fields,
"simplex": [*common_fields, "xTolerance", "funcTolerance", "maxFuncEvals", "maxIterations", *update_fields],
"de": [
*common_fields,
"populationSize",
"fWeight",
"crossoverProbability",
"strategy",
"targetValue",
"numGenerations",
*update_fields,
],
"ns": [*common_fields, "nLive", "nMCMC", "propScale", "nsTolerance"],
"dream": [*common_fields, "nSamples", "nChains", "jumpProbability", "pUnitGamma", "boundHandling", "adaptPCR"],
}


class Controls(BaseModel, validate_assignment=True, extra="forbid"):
"""The full set of controls parameters for all five procedures that are required for the compiled RAT code."""

# All Procedures
procedure: Procedures = Procedures.Calculate
parallel: Parallel = Parallel.Single
calcSldDuringFit: bool = False
resampleParams: list[float] = field(default_factory=list[0.9, 50.0])
display: Display = Display.Iter
# Simplex
xTolerance: float = 1.0e-6
funcTolerance: float = 1.0e-6
maxFuncEvals: int = 10000
maxIterations: int = 1000
updateFreq: int = -1
updatePlotFreq: int = 1
# DE
populationSize: int = 20
fWeight: float = 0.5
crossoverProbability: float = 0.8
strategy: Strategies = Strategies.RandomWithPerVectorDither.value
targetValue: float = 1.0
numGenerations: int = 500
# NS
nLive: int = 150
nMCMC: float = 0.0
propScale: float = 0.1
nsTolerance: float = 0.1
# Dream
nSamples: int = 20000
nChains: int = 10
jumpProbability: float = 0.5
pUnitGamma: float = 0.2
boundHandling: BoundHandling = BoundHandling.Reflect
adaptPCR: bool = True


class Calculate(BaseModel, validate_assignment=True, extra="forbid"):
"""Defines the class for the calculate procedure, which includes the properties used in all five procedures."""

procedure: Literal[Procedures.Calculate] = Procedures.Calculate
parallel: Parallel = Parallel.Single
calcSldDuringFit: bool = False
resampleParams: list[float] = Field([0.9, 50], min_length=2, max_length=2)
display: Display = Display.Iter

@field_validator("resampleParams")
@classmethod
def check_resample_params(cls, resampleParams):
if not 0 < resampleParams[0] < 1:
raise ValueError("resampleParams[0] must be between 0 and 1")
if resampleParams[1] < 0:
raise ValueError("resampleParams[1] must be greater than or equal to 0")
return resampleParams

def __repr__(self) -> str:
table = prettytable.PrettyTable()
table.field_names = ["Property", "Value"]
table.add_rows([[k, v] for k, v in self.__dict__.items()])
return table.get_string()


class Simplex(Calculate):
"""Defines the additional fields for the simplex procedure."""

procedure: Literal[Procedures.Simplex] = Procedures.Simplex
# Simplex
xTolerance: float = Field(1.0e-6, gt=0.0)
funcTolerance: float = Field(1.0e-6, gt=0.0)
maxFuncEvals: int = Field(10000, gt=0)
maxIterations: int = Field(1000, gt=0)
updateFreq: int = -1
updatePlotFreq: int = 1


class DE(Calculate):
"""Defines the additional fields for the Differential Evolution procedure."""

procedure: Literal[Procedures.DE] = Procedures.DE
# Simplex and DE
updateFreq: int = 1
updatePlotFreq: int = 20
# DE
populationSize: int = Field(20, ge=1)
fWeight: float = 0.5
crossoverProbability: float = Field(0.8, gt=0.0, lt=1.0)
strategy: Strategies = Strategies.RandomWithPerVectorDither
targetValue: float = Field(1.0, ge=1.0)
numGenerations: int = Field(500, ge=1)


class NS(Calculate):
"""Defines the additional fields for the Nested Sampler procedure."""

procedure: Literal[Procedures.NS] = Procedures.NS
# NS
nLive: int = Field(150, ge=1)
nMCMC: float = Field(0.0, ge=0.0)
propScale: float = Field(0.1, gt=0.0, lt=1.0)
nsTolerance: float = Field(0.1, ge=0.0)


class Dream(Calculate):
"""Defines the additional fields for the Dream procedure."""

procedure: Literal[Procedures.Dream] = Procedures.Dream
# Dream
nSamples: int = Field(20000, ge=0)
nChains: int = Field(10, gt=0)
jumpProbability: float = Field(0.5, gt=0.0, lt=1.0)
pUnitGamma: float = Field(0.2, gt=0.0, lt=1.0)
boundHandling: BoundHandling = BoundHandling.Reflect
adaptPCR: bool = True

@model_validator(mode="wrap")
def warn_setting_incorrect_properties(self, handler: ValidatorFunctionWrapHandler) -> "Controls":
"""Raise a warning if the user sets fields that apply to other procedures."""
model_input = self
try:
input_dict = model_input.__dict__
except AttributeError:
input_dict = model_input

validated_self = None
try:
validated_self = handler(self)
except ValidationError as exc:
procedure = input_dict.get("procedure", Procedures.Calculate)
custom_error_msgs = {
"extra_forbidden": f'Extra inputs are not permitted. The fields for the "{procedure}"'
f' controls procedure are:\n '
f'{", ".join(fields.get("procedure", []))}\n',
}
custom_error_list = custom_pydantic_validation_error(exc.errors(), custom_error_msgs)
raise ValidationError.from_exception_data(exc.title, custom_error_list) from None

if isinstance(model_input, validated_self.__class__):
# This is for changing fields in a defined model
changed_fields = [key for key in input_dict if input_dict[key] != validated_self.__dict__[key]]
elif isinstance(model_input, dict):
# This is for a newly-defined model
changed_fields = input_dict.keys()
else:
raise ValueError('The input to the "Controls" model is invalid.')

new_procedure = validated_self.procedure
allowed_fields = fields[new_procedure]
for field in changed_fields:
if field not in allowed_fields:
incorrect_procedures = [key for (key, value) in fields.items() if field in value]
warnings.warn(
f'\nThe current controls procedure is "{new_procedure}", but the property'
f' "{field}" applies instead to the {", ".join(incorrect_procedures)} procedure.\n\n'
f' The fields for the "{new_procedure}" controls procedure are:\n'
f' {", ".join(fields[new_procedure])}\n',
stacklevel=2,
)

return validated_self

@field_validator("resampleParams")
@classmethod
def check_resample_params(cls, values: list[float]) -> list[float]:
"""Make sure each of the two values of resampleParams satisfy their conditions."""
if not 0 < values[0] < 1:
raise ValueError("resampleParams[0] must be between 0 and 1")
if values[1] < 0:
raise ValueError("resampleParams[1] must be greater than or equal to 0")
return values

def set_controls(
procedure: Procedures = Procedures.Calculate,
**properties,
) -> Union[Calculate, Simplex, DE, NS, Dream]:
"""Returns the appropriate controls model given the specified procedure."""
controls = {
Procedures.Calculate: Calculate,
Procedures.Simplex: Simplex,
Procedures.DE: DE,
Procedures.NS: NS,
Procedures.Dream: Dream,
}
@model_serializer
def serialize(self):
"""Filter fields so only those applying to the chosen procedure are serialized."""
return {model_field: getattr(self, model_field) for model_field in fields[self.procedure]}

try:
model = controls[procedure](**properties)
except KeyError:
members = list(Procedures.__members__.values())
allowed_values = f'{", ".join([repr(member.value) for member in members[:-1]])} or {members[-1].value!r}'
raise ValueError(f"The controls procedure must be one of: {allowed_values}") from None
except ValidationError as exc:
custom_error_msgs = {
"extra_forbidden": f'Extra inputs are not permitted. The fields for the {procedure}'
f' controls procedure are:\n '
f'{", ".join(controls[procedure].model_fields.keys())}\n',
}
custom_error_list = custom_pydantic_validation_error(exc.errors(), custom_error_msgs)
raise ValidationError.from_exception_data(exc.title, custom_error_list) from None
def __repr__(self) -> str:
fields_repr = ", ".join(repr(v) if a is None else f"{a}={v!r}" for a, v in self.model_dump().items())
return f"{self.__repr_name__()}({fields_repr})"

return model
def __str__(self) -> str:
table = prettytable.PrettyTable()
table.field_names = ["Property", "Value"]
table.add_rows([[k, v] for k, v in self.model_dump().items()])
return table.get_string()
2 changes: 1 addition & 1 deletion RATpy/examples/absorption/absorption.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@
)

# Now make a controls block
controls = RAT.set_controls(parallel="contrasts", resampleParams=[0.9, 150.0])
controls = RAT.Controls(parallel="contrasts", resampleParams=[0.9, 150.0])

# Run the code and plot the results
problem, results = RAT.run(problem, controls)
Expand Down
2 changes: 1 addition & 1 deletion RATpy/examples/domains/domains_custom_XY.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
model=["Domain Layer"],
)

controls = RAT.set_controls()
controls = RAT.Controls()
problem, results = RAT.run(problem, controls)

RAT.plotting.plot_ref_sld(problem, results, True)
2 changes: 1 addition & 1 deletion RATpy/examples/domains/domains_custom_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
model=["Alloy domains"],
)

controls = RAT.set_controls()
controls = RAT.Controls()

problem, results = RAT.run(problem, controls)
RAT.plotting.plot_ref_sld(problem, results, True)
2 changes: 1 addition & 1 deletion RATpy/examples/domains/domains_standard_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@


# Now we can run our simulation as usual, and plot the results
controls = RAT.set_controls()
controls = RAT.Controls()

problem, results = RAT.run(problem, controls)
RAT.plotting.plot_ref_sld(problem, results, True)
2 changes: 1 addition & 1 deletion RATpy/examples/languages/run_custom_file_languages.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
path = pathlib.Path(__file__).parent.resolve()

project = setup_problem.make_example_problem()
controls = RAT.set_controls()
controls = RAT.Controls()

# Python
start = time.time()
Expand Down
2 changes: 1 addition & 1 deletion RATpy/examples/non_polarised/DSPC_custom_XY.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@
model=["DSPC Model"],
)

controls = RAT.set_controls()
controls = RAT.Controls()

problem, results = RAT.run(problem, controls)
RAT.plotting.plot_ref_sld(problem, results, True)
2 changes: 1 addition & 1 deletion RATpy/examples/non_polarised/DSPC_custom_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@
model=["DSPC Model"],
)

controls = RAT.set_controls()
controls = RAT.Controls()

problem, results = RAT.run(problem, controls)
RAT.plotting.plot_ref_sld(problem, results, True)
2 changes: 1 addition & 1 deletion RATpy/examples/non_polarised/DSPC_standard_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@
model=stack,
)

controls = RAT.set_controls()
controls = RAT.Controls()

problem, results = RAT.run(problem, controls)
RAT.plotting.plot_ref_sld(problem, results, True)
Loading