52 changes: 1 addition & 51 deletions disruption_py/core/physics_method/caching.py
@@ -7,10 +7,9 @@

import functools
import threading
from typing import Callable, List
from typing import Callable

import numpy as np
import pandas as pd

from disruption_py.core.physics_method.params import PhysicsMethodParams

@@ -89,52 +88,3 @@ def get_method_cache_key(
len(times),
hashable_other_params,
)


def manually_cache(
physics_method_params: PhysicsMethodParams,
data: pd.DataFrame,
method: Callable,
method_name: str,
method_columns: List[str],
) -> bool:
"""
Manually cache results based on the provided DataFrame and method details.

Parameters
----------
physics_method_params : PhysicsMethodParams
The parameters containing the shot ID and logger for logging.
data : pd.DataFrame
The DataFrame containing the data to be cached.
method : Callable
The method for which the results are being cached.
method_name : str
The name of the method being cached, used for logging.
method_columns : List[str]
The list of columns to check and cache.

Returns
-------
bool
True if caching was successful, False if there were missing columns.
"""
if method_columns is None:
return False
if not hasattr(physics_method_params, "cached_results"):
physics_method_params.cached_results = {}
missing_columns = set(col for col in method_columns if col not in data.columns)
if len(missing_columns) == 0:
cache_key = get_method_cache_key(method, data["time"].values)
physics_method_params.cached_results[cache_key] = data[method_columns]
physics_method_params.logger.debug(
"Manually caching {method_name}",
method_name=method_name,
)
return True
physics_method_params.logger.debug(
"Can not cache {method_name} missing columns {missing_columns}",
method_name=method_name,
missing_columns=",".join(missing_columns),
)
return False
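With manually_cache() removed, the only caching hook left in this module is the cache-key builder. As a hedged sketch (not the project's actual decorator), a key like the one returned by get_method_cache_key can drive a simple per-shot memoization; only the cached_results attribute and params.times are taken from this diff, the rest is invented for illustration:

```python
# Hypothetical decorator, for illustration only: it memoizes a physics method
# on a key similar to the one built by get_method_cache_key. The
# `cached_results` attribute mirrors the one used by the removed
# manually_cache(); the key itself is simplified to (method, timebase length).
import functools
from typing import Any, Callable


def cache_by_key(method: Callable) -> Callable:
    """Memoize a physics method on a (method, timebase-length) key."""

    @functools.wraps(method)
    def wrapper(*, params: Any):
        if not hasattr(params, "cached_results"):
            params.cached_results = {}
        key = (method, len(params.times))  # stand-in for get_method_cache_key
        if key not in params.cached_results:
            params.cached_results[key] = method(params=params)
        return params.cached_results[key]

    return wrapper
```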
3 changes: 0 additions & 3 deletions disruption_py/core/physics_method/params.py
@@ -8,7 +8,6 @@
from typing import Any, Dict

import numpy as np
import pandas as pd
from loguru import logger

from disruption_py.core.utils.misc import shot_msg_patch
@@ -28,8 +27,6 @@ class PhysicsMethodParams:
disruption_time: float
mds_conn: MDSConnection
times: np.ndarray
cache_data: pd.DataFrame
pre_filled_shot_data: pd.DataFrame
interpolation_method: Any # Fix
metadata: dict

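For reference, a partial sketch of the slimmed-down parameter container after this change. Only the fields visible in this diff plus shot_id (referenced from runner.py) are listed; the real class defines more, and MDSConnection is stubbed with Any:

```python
# Partial, assumed sketch of PhysicsMethodParams after dropping cache_data and
# pre_filled_shot_data; field list is incomplete and types are approximate.
from dataclasses import dataclass
from typing import Any

import numpy as np


@dataclass
class PhysicsMethodParamsSketch:
    shot_id: int
    disruption_time: float
    mds_conn: Any  # stands in for disruption_py's MDSConnection
    times: np.ndarray
    interpolation_method: Any
    metadata: dict
```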
132 changes: 33 additions & 99 deletions disruption_py/core/physics_method/runner.py
@@ -12,8 +12,6 @@
import pandas as pd
from MDSplus import mdsExceptions

from disruption_py.config import config
from disruption_py.core.physics_method.caching import manually_cache
from disruption_py.core.physics_method.errors import CalculationError
from disruption_py.core.physics_method.metadata import (
BoundMethodMetadata,
@@ -28,48 +26,6 @@
REQUIRED_COLS = {"shot", "time"}


def get_prefilled_shot_data(physics_method_params: PhysicsMethodParams) -> pd.DataFrame:
"""
Retrieve pre-filled shot data for the given physics method parameters.

Parameters
----------
physics_method_params : PhysicsMethodParams
Parameters containing MDS connection and shot information

Returns
-------
pd.DataFrame
A DataFrame containing the pre-filled shot data.
"""
pre_filled_shot_data = physics_method_params.pre_filled_shot_data

# If the shot object was already passed data in the constructor, use that data.
# Otherwise, create an empty dataframe.
if pre_filled_shot_data is None:
pre_filled_shot_data = pd.DataFrame()
if "time" not in pre_filled_shot_data:
pre_filled_shot_data["time"] = physics_method_params.times
if "shot" not in pre_filled_shot_data:
pre_filled_shot_data["shot"] = int(physics_method_params.shot_id)

# Check that pre_filled_shot_data is on the same timebase as the shot object
# to ensure data consistency
if (
len(pre_filled_shot_data["time"]) != len(physics_method_params.times)
or not np.isclose(
pre_filled_shot_data["time"],
physics_method_params.times,
atol=config().time_const,
).all()
):
physics_method_params.logger.error(
"Computation on different timebase than pre-filled shot data",
physics_method_params.shot_id,
)
return pre_filled_shot_data


def get_all_physics_methods(all_passed: list) -> set:
"""
Retrieve all callable physics methods from the provided list.
@@ -172,22 +128,22 @@ def filter_methods_to_run(

both_none = methods is None and columns is None
method_specified = methods is not None and bound_method_metadata.name in methods
column_speficied = columns is not None and bool(
column_specified = columns is not None and bool(
set(bound_method_metadata.columns).intersection(columns)
)
is_not_excluded = (
only_excluded_methods_specified
and not columns
and methods
and (("~" + bound_method_metadata.name) not in methods)
and f"~{bound_method_metadata.name}" not in methods
)
should_run = (
both_none or method_specified or column_speficied or is_not_excluded
both_none or method_specified or column_specified or is_not_excluded
)

# reasons that methods should be exluded from should run
# reasons that methods should be excluded from should run
should_not_run = (
methods is not None and ("~" + bound_method_metadata.name) in methods
methods is not None and f"~{bound_method_metadata.name}" in methods
)

if should_run and not should_not_run:
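Aside: the selection rules above combine inclusion by method name or column with a "~name" exclusion syntax. A standalone, hedged re-implementation for illustration, assuming only_excluded_methods_specified means every requested method starts with "~" (that definition sits outside this diff, and all names below are invented):

```python
# Illustrative re-implementation of the filtering rules; not the library's API.
def should_run_method(name, provided_columns, methods=None, columns=None):
    only_excluded = bool(methods) and all(m.startswith("~") for m in methods)
    both_none = methods is None and columns is None
    method_specified = methods is not None and name in methods
    column_specified = columns is not None and bool(set(provided_columns) & set(columns))
    is_not_excluded = only_excluded and not columns and f"~{name}" not in methods
    excluded = methods is not None and f"~{name}" in methods
    return (both_none or method_specified or column_specified or is_not_excluded) and not excluded


assert should_run_method("get_ip_parameters", {"ip", "dip_dt"}, columns=["ip"])
assert should_run_method("get_kappa", {"kappa"}, methods=["~get_ip_parameters"])
assert not should_run_method("get_ip_parameters", {"ip"}, methods=["~get_ip_parameters"])
```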
@@ -222,7 +178,6 @@ def populate_method(
name = bound_method_metadata.name

physics_method_params.logger.trace("Starting method: {name}", name=name)
result = None

try:
result = method(params=physics_method_params)
@@ -290,42 +245,19 @@ def populate_shot(
all_bound_method_metadata, retrieval_settings, physics_method_params
)

pre_filled_shot_data = get_prefilled_shot_data(physics_method_params)
# Manually cache data that has already been retrieved (likely from SQL tables)
# Methods added to pre_cached_method_names will be skipped by method optimizer
cached_method_metadata = []
if physics_method_params.pre_filled_shot_data is not None:
for method_metadata in all_bound_method_metadata:
cache_success = manually_cache(
physics_method_params=physics_method_params,
data=pre_filled_shot_data,
method=method_metadata.bound_method,
method_name=method_metadata.name,
method_columns=method_metadata.columns,
)
if cache_success:
cached_method_metadata.append(method_metadata)
if method_metadata in run_bound_method_metadata:
physics_method_params.logger.verbose(
"Cached method: {name}",
name=method_metadata.name,
)

# run methods and collect data
start_time = time.time()
methods_data = []
for bound_method_metadata in run_bound_method_metadata:
if bound_method_metadata in cached_method_metadata:
continue
methods_data.append(
populate_method(
physics_method_params=physics_method_params,
bound_method_metadata=bound_method_metadata,
)
methods_data = [
populate_method(
physics_method_params=physics_method_params,
bound_method_metadata=bound_method_metadata,
)
for bound_method_metadata in run_bound_method_metadata
]

# Initialize with cached data
num_parameters = len(pre_filled_shot_data.columns)
num_valid = pre_filled_shot_data.notna().any().sum()
# create DataFrames of proper shape
num_parameters = 0
num_valid = 0
filtered_methods = []
for method_dict in methods_data:
if method_dict is None:
@@ -338,20 +270,19 @@
np.all(np.isnan(method_dict[parameter]))
and len(method_dict[parameter]) == 1
):
method_dict[parameter] = np.full(len(pre_filled_shot_data), np.nan)
method_dict[parameter] = physics_method_params.times * np.nan
else:
num_valid += 1
method_df = pd.DataFrame(method_dict)
if len(method_df) != len(pre_filled_shot_data):
if len(method_df) != len(physics_method_params.times):
physics_method_params.logger.error(
"Ignoring parameters {parameter} with different length than timebase",
parameter=list(method_dict.keys()),
)
# TODO: Should we drop the columns, or is it better to raise an
# exception when the data do not match?
continue
filtered_methods.append(method_df)

# log statistics
percent_valid = (num_valid / num_parameters * 100) if num_parameters else 0
if percent_valid >= 75:
level = "SUCCESS"
@@ -375,21 +306,24 @@
elapsed=get_elapsed_time(time.time() - start_time),
)

# TODO: This is a hack to get around the fact that some methods return
# multiple parameters. This should be fixed in the future.

local_data = pd.concat([pre_filled_shot_data] + filtered_methods, axis=1)
# concatenate partial DataFrames
coords = pd.DataFrame(
{
"shot": [physics_method_params.shot_id] * len(physics_method_params.times),
"time": physics_method_params.times,
}
)
local_data = pd.concat([coords] + filtered_methods, axis=1)
local_data = local_data.loc[:, ~local_data.columns.duplicated()]

# include requested columns
include_cols = set(local_data.columns).difference(REQUIRED_COLS)
if (
retrieval_settings.only_requested_columns
and retrieval_settings.run_columns is not None
):
include_columns = list(
REQUIRED_COLS.union(
set(retrieval_settings.run_columns).intersection(
set(local_data.columns)
)
)
)
local_data = local_data[include_columns]
include_cols = set(retrieval_settings.run_columns).intersection(include_cols)

# sort columns
local_data = local_data[list(REQUIRED_COLS) + sorted(list(include_cols))]
return local_data
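To summarize the new assembly step in populate_shot, here is a minimal, self-contained sketch with made-up shot data: per-method result dicts are length-checked against the timebase, turned into DataFrames, and prefixed with the shot/time coordinate columns before duplicate columns are dropped.

```python
# Standalone sketch of the assembly step; shot number and signal values are
# fabricated for illustration and do not come from any real shot.
import numpy as np
import pandas as pd

times = np.linspace(0.0, 1.0, 5)
shot_id = 123456

methods_data = [
    {"ip": np.arange(5, dtype=float)},     # matches the timebase: kept
    {"bad_signal": np.array([1.0, 2.0])},  # wrong length: dropped with a warning
]

filtered_methods = []
for method_dict in methods_data:
    method_df = pd.DataFrame(method_dict)
    if len(method_df) != len(times):
        print(f"ignoring {list(method_dict)}: length differs from timebase")
        continue
    filtered_methods.append(method_df)

coords = pd.DataFrame({"shot": [shot_id] * len(times), "time": times})
local_data = pd.concat([coords] + filtered_methods, axis=1)
local_data = local_data.loc[:, ~local_data.columns.duplicated()]
print(local_data)  # columns: shot, time, ip
```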