Skip to content
Merged
2 changes: 1 addition & 1 deletion disruption_py/core/physics_method/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def bind(
# Utility methods for decorated methods


def is_parametered_method(method: Callable) -> bool:
def is_physics_method(method: Callable) -> bool:
"""
Returns whether the method is decorated with `physics_method` decorator

Expand Down
2 changes: 0 additions & 2 deletions disruption_py/core/physics_method/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@ class PhysicsMethodParams:
disruption_time: float
mds_conn: MDSConnection
times: np.ndarray
interpolation_method: Any # Fix
metadata: dict

def __post_init__(self):
self.logger = shot_msg_patch(logger, self.shot_id)
Expand Down
6 changes: 3 additions & 3 deletions disruption_py/core/physics_method/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from disruption_py.core.physics_method.metadata import (
BoundMethodMetadata,
get_method_metadata,
is_parametered_method,
is_physics_method,
)
from disruption_py.core.physics_method.params import PhysicsMethodParams
from disruption_py.core.utils.misc import get_elapsed_time
Expand All @@ -42,12 +42,12 @@ def get_all_physics_methods(all_passed: list) -> set:
"""
physics_methods = set()
for passed in all_passed:
if callable(passed) and is_parametered_method(passed):
if callable(passed) and is_physics_method(passed):
physics_methods.add(passed)

for method_name in dir(passed):
method = getattr(passed, method_name, None)
if method is None or not is_parametered_method(method):
if method is None or not is_physics_method(method):
continue
physics_methods.add(method)
return physics_methods
Expand Down
13 changes: 0 additions & 13 deletions disruption_py/core/retrieval_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

from disruption_py.core.physics_method.params import PhysicsMethodParams
from disruption_py.core.physics_method.runner import populate_shot
from disruption_py.core.utils.math import interp1
from disruption_py.core.utils.misc import shot_msg
from disruption_py.inout.mds import MDSConnection, ProcessMDSConnection
from disruption_py.inout.sql import ShotDatabase
Expand Down Expand Up @@ -196,31 +195,19 @@ def setup_physics_method_params(
The configured physics method parameters.
"""

interpolation_method = interp1 # TODO: fix

times = self._init_times(
shot_id=shot_id,
mds_conn=mds_conn,
disruption_time=disruption_time,
retrieval_settings=retrieval_settings,
)

metadata = {
"labels": {},
"timestep": {},
"duration": {},
"description": "",
"disrupted": 100, # TODO: Fix
}

physics_method_params = PhysicsMethodParams(
shot_id=shot_id,
tokamak=self.tokamak,
disruption_time=disruption_time,
mds_conn=mds_conn,
times=times,
interpolation_method=interpolation_method,
metadata=metadata,
)

# Modify already existing shot properties, such as modifying timebase
Expand Down
26 changes: 0 additions & 26 deletions disruption_py/core/utils/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,32 +5,6 @@
to corresponding enum values and to convert string values to enum values.
"""

from typing import Dict


def map_string_attributes_to_enum(obj, enum_translations: Dict):
"""
Map string attributes of an object to corresponding enum values.

Parameters
----------
obj : object
The object whose attributes will be mapped to enum values.
enum_translations : Dict[str, type]
A dictionary mapping attribute names to their corresponding enum classes.
"""
for field_name, enum_class in enum_translations.items():
if hasattr(obj, field_name) and not isinstance(
getattr(obj, field_name), enum_class
):
try:
enum_value = enum_class(getattr(obj, field_name))
setattr(obj, field_name, enum_value)
except ValueError as e:
raise ValueError(
f"Invalid enum value for field '{field_name}': {getattr(obj, field_name)}"
) from e


def map_string_to_enum(value, enum_class, should_raise=True):
"""
Expand Down
75 changes: 0 additions & 75 deletions disruption_py/core/utils/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,40 +87,6 @@ def interp1(x, y, new_x, kind="linear", bounds_error=False, fill_value=np.nan, a
return set_interp(new_x)


def exp_filter(x, w, strategy="fragmented"):
"""
Implements an exponential filter.

This function implements an exponential filter on the given array x. In the
case of nan values in the input array, we default to using the last timestep
that was not a nan value. In the fragmented strategy, any time we encounter
invald values, we restart the filter at the next valid value.

Parameters
----------
x : array
The array to filter.
w : float
The filter weight.
strategy: str, optional
Imputation strategy to be used, if any. Options are 'fragmented' or 'none'.
Default is 'fragmented.'

Returns
-------
_ : array
The filtered array.
"""
filtered_x = np.zeros(x.shape)
filtered_x[0] = x[0]
for i in range(1, len(x)):
filtered_x[i] = w * x[i] + (1 - w) * filtered_x[i - 1]
if strategy == "fragmented":
if np.isnan(filtered_x[i - 1]):
filtered_x[i] = x[i]
return filtered_x


def smooth(arr: np.ndarray, window_size: int) -> np.ndarray:
"""
Implements Matlab's smooth function https://www.mathworks.com/help/curvefit/smooth.html.
Expand All @@ -145,47 +111,6 @@ def smooth(arr: np.ndarray, window_size: int) -> np.ndarray:
return np.concatenate((start, mid, end))


def gauss_smooth(y, smooth_width, ends_type):
"""
Smooth a dataset using a Gaussian window.

Parameters
----------
y : array_like
The y coordinates of the dataset.
smooth_width : int
The width of the smoothing window.
ends_type : int
Determines how the "ends" of the signal are handled.
0 -> ends are "zeroed"
1 -> the ends are smoothed with progressively smaller smooths the closer to the end.

Returns
-------
array_like
The smoothed dataset.
"""
w = np.round(smooth_width)
w = int(w) # Ensure w is an integer
ly = len(y)
s = np.zeros(ly)

for i in range(ly):
if i < w // 2:
if ends_type == 0:
s[i] = 0
else:
s[i] = np.mean(y[: i + w // 2])
elif i >= ly - w // 2:
if ends_type == 0:
s[i] = 0
else:
s[i] = np.mean(y[i - w // 2 :])
else:
s[i] = np.mean(y[i - w // 2 : i + w // 2])
return s


@filter_cov_warning
def gaussian_fit(*args, **kwargs):
"""
Expand Down
17 changes: 0 additions & 17 deletions disruption_py/core/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,23 +18,6 @@
from loguru import logger


def instantiate_classes(lst: List):
"""
Instantiate all classes in a list of classes and objects.

Parameters
----------
lst : List
List to instantiate classes from.

Returns
-------
List
The list with all classes instantiated.
"""
return [x() for x in lst if isinstance(x, type)]


def without_duplicates(lst: List):
"""
Get list without duplicates while maintaining order.
Expand Down
Loading