Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
240 changes: 123 additions & 117 deletions src/scores/spatial/cra_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,6 @@ def _translate_forecast_region(
# How can the following actually occur? What should users do about it? What does it mean?
# If we don't know, should we warn instead?
# If not, should we just let the user decide?

# if (
# rmse_shifted > rmse_original
# or corr_shifted < corr_original
Expand Down Expand Up @@ -388,7 +387,7 @@ def _calc_num_points(data: xr.DataArray, threshold: float) -> int:
>>> count = _calc_num_points(data, threshold=5.0)
"""
mask = data >= threshold
count_above_threshold = mask.sum().item()
count_above_threshold = int(mask.sum().compute())
return count_above_threshold


Expand Down Expand Up @@ -458,6 +457,13 @@ def _calc_resolution(obs: xr.DataArray, spatial_dims: list[str], units: str) ->
return float(avg_resolution_km)


def _collect_coords(da: xr.DataArray, y_name: str, x_name: str, time_name: Optional[str]):
coord_names = [y_name, x_name]
if time_name and time_name in da.coords:
coord_names.insert(0, time_name)
return {name: da[name] for name in coord_names}


def _cra_image(
fcst: xr.DataArray,
obs: xr.DataArray,
Expand Down Expand Up @@ -501,12 +507,39 @@ def _cra_image(
x_name (str): Name of the zonal spatial dimension (e.g., 'lon', 'projection_x_coordinate').
max_distance (float): Maximum allowed translation distance in kilometres.
min_points (int): Minimum number of grid points required in a blob.
coord_units (str) : Coordinate units, 'degrees' or 'metres'
coord_units (str) : Coordinate units; either 'degrees' or 'metres'
extra_components (bool) : If True, include extended diagnostics such as
max/mean rainfall in forecast/observation blobs, RMSE before and after
shifting, and correlation coefficients.
time_name (str): Name of the dimension to use for time-series (e.g. 'time', 'lead_time' or 'valid_time')

Returns:
`CRA2DMetric`: A dictionary containing the CRA components and diagnostics.
xr.Dataset: A dataset containing CRA metrics and diagnostics for the 2D slice.

**Core variables (always present)**:
- ``mse_total`` (float): Total mean squared error between forecast and observed blobs.
- ``mse_displacement`` (float): MSE due to spatial displacement between forecast and observed blobs.
- ``mse_volume`` (float): MSE due to volume differences.
- ``mse_pattern`` (float): MSE due to pattern/structure differences.
- ``mse_shift`` (float): MSE of the shifted forecast blob.

**Extended variables (included if ``extra_components=True``)**:
- ``optimal_shift`` (1D): Optimal shift [dx, dy] in the ``optimal_shit`` dimension in grid points units.
- ``fcst_blob`` (2D): Forecast blob values
- ``obs_blob`` (2D): Observation blob values
- ``shifted_fcst`` (2D): Shifted forecast blob values after optimal shift.
- ``num_gridpoints_above_threshold_fcst`` (int): Number of grid points in forecast blob above threshold.
- ``num_gridpoints_above_threshold_obs`` (int): Number of grid points in observed blob above threshold.
- ``avg_fcst`` (float): Mean value of the forecast blob.
- ``avg_obs`` (float): Mean value of the observed blob.
- ``max_fcst`` (float): Maximum value in the forecast blob.
- ``max_obs`` (float): Maximum value in the observed blob.
- ``corr_coeff_original`` (float): Correlation coefficient between original forecast and observed blobs.
- ``corr_coeff_shifted`` (float): Correlation coefficient between shifted forecast and observed blobs.
- ``rmse_original`` (float): Root mean square error between original forecast and observed blobs.
- ``rmse_shifted`` (float): Root mean square error between shifted forecast and observed blobs.


Returns an object containing NaNs if input data is invalid or CRA computation fails.

Raises:
ValueError: If input shapes do not match or blobs cannot be computed.
Expand All @@ -522,8 +555,12 @@ def _cra_image(
>>> import xarray as xr
>>> fcst = xr.DataArray(...) # your forecast data
>>> obs = xr.DataArray(...) # your observation data
>>> result = cra(fcst, obs, threshold=5.0)
>>> print(result['mse_total'])
>>> result = _cra_image(
... fcst, obs, threshold=5.0.
... y_name="projection_y_coordinate", x_name="projection_x_coordinate",
... extra_components=True,
... )
>>> float(result['mse_total'])

"""

Expand All @@ -532,11 +569,56 @@ def _cra_image(

fcst_blob, obs_blob = _generate_largest_rain_area_2d(fcst, obs, threshold, min_points)

# If either blob is missing/invalid, return a NaN dataset early
if fcst_blob is None or obs_blob is None:
return xr.Dataset(
data_vars={
"mse_total": xr.DataArray(np.nan),
"mse_shift": xr.DataArray(np.nan),
"mse_displacement": xr.DataArray(np.nan),
"mse_volume": xr.DataArray(np.nan),
"mse_pattern": xr.DataArray(np.nan),
},
coords=_collect_coords(obs, y_name, x_name, time_name),
attrs={
"cra_valid": False,
"reason": "no_valid_blob",
"threshold": float(threshold),
"coord_units": str(coord_units),
"max_distance": float(max_distance),
"min_points": int(min_points),
"extra_components": bool(extra_components),
},
)

mse_total = mse(fcst_blob, obs_blob)

shifted_fcst, delta_x, delta_y = _translate_forecast_region(
fcst_blob, obs_blob, y_name, x_name, max_distance, coord_units
)

if shifted_fcst is None or delta_x is None or delta_y is None:
# Return a consistent NaN dataset
return xr.Dataset(
data_vars={
"mse_total": xr.DataArray(np.nan),
"mse_shift": xr.DataArray(np.nan),
"mse_displacement": xr.DataArray(np.nan),
"mse_volume": xr.DataArray(np.nan),
"mse_pattern": xr.DataArray(np.nan),
},
coords=_collect_coords(obs, y_name, x_name, time_name),
attrs={
"cra_valid": False,
"reason": "no_valid_translation",
"threshold": float(threshold),
"coord_units": str(coord_units),
"max_distance": float(max_distance),
"min_points": int(min_points),
"extra_components": bool(extra_components),
},
)

optimal_shift = [delta_x, delta_y]

mse_shift = mse(shifted_fcst, obs_blob)
Expand Down Expand Up @@ -585,8 +667,8 @@ def _cra_image(


def cra(
fcst: XarrayLike,
obs: XarrayLike,
fcst: xr.DataArray,
obs: xr.DataArray,
threshold: float,
y_name: str,
x_name: str,
Expand Down Expand Up @@ -626,27 +708,36 @@ def cra(
max_distance (float): Maximum allowed translation distance in kilometres.
min_points (int): Minimum number of grid points required in a blob.
reduce_dims (list[str] or str, optional): Dimension to group by (default: ["time"]).
coord_units (str) : Coordinate units, 'degrees' or 'metres'
coord_units (str) : Coordinate units, either 'degrees' or 'metres'
extra_components (bool) : If True, include extended diagnostics such as
max/mean rainfall in forecast/observation blobs, RMSE before and after
shifting, and correlation coefficients.

Returns:
A dictionary where each key corresponds to a CRA metric and maps to a list of
values, one for each slice along the specified grouping dimension (e.g. time):
- mse_total (list[float]): Total mean squared error between forecast and observed blobs.
- mse_displacement (list[float]): MSE due to spatial displacement between forecast and observed blobs.
- mse_volume (list[float]): MSE due to volume differences.
- mse_pattern (list[float]): MSE due to pattern/structure differences.
- optimal_shift (list[list[int]]): Optimal [x, y] shift applied to forecast blob in grid points units.
- num_gridpoints_above_threshold_fcst (list[int]): Number of grid points in forecast blob above threshold.
- num_gridpoints_above_threshold_obs (list[int]): Number of grid points in observed blob above threshold.
- avg_fcst (list[float]): Mean value of the forecast blob.
- avg_obs (list[float]): Mean value of the observed blob.
- max_fcst (list[float]): Maximum value in the forecast blob.
- max_obs (list[float]): Maximum value in the observed blob.
- corr_coeff_original (list[float]): Correlation coefficient between original forecast and observed blobs.
- corr_coeff_shifted (list[float]): Correlation coefficient between shifted forecast and observed blobs.
- rmse_original (list[float]): Root mean square error between original forecast and observed blobs.
- rmse_shifted (list[float]): Root mean square error between shifted forecast and observed blobs.
Returns None if input data is invalid or CRA computation fails.
xr.Dataset: CRA metrics and diagnostics for each slice along the extra dimensions.

**Core variables (always present)**:
- ``mse_total`` (float): Total mean squared error between forecast and observed blobs.
- ``mse_displacement`` (float): MSE due to spatial displacement between forecast and observed blobs.
- ``mse_volume`` (float): MSE due to volume differences.
- ``mse_pattern`` (float): MSE due to pattern/structure differences.
- ``mse_shift`` (float): MSE of the shifted forecast blob.

**Extended variables (included if ``extra_components=True``)**:
- ``optimal_shift`` (1D): Optimal shift [dx, dy] in the ``optimal_shit`` dimension in grid points units.
- ``fcst_blob`` (2D): Forecast blob values
- ``obs_blob`` (2D): Observation blob values
- ``shifted_fcst`` (2D): Shifted forecast blob values after optimal shift.
- ``num_gridpoints_above_threshold_fcst`` (int): Number of grid points in forecast blob above threshold.
- ``num_gridpoints_above_threshold_obs`` (int): Number of grid points in observed blob above threshold.
- ``avg_fcst`` (float): Mean value of the forecast blob.
- ``avg_obs`` (float): Mean value of the observed blob.
- ``max_fcst`` (float): Maximum value in the forecast blob.
- ``max_obs`` (float): Maximum value in the observed blob.
- ``corr_coeff_original`` (float): Correlation coefficient between original forecast and observed blobs.
- ``corr_coeff_shifted`` (float): Correlation coefficient between shifted forecast and observed blobs.
- ``rmse_original`` (float): Root mean square error between original forecast and observed blobs.
- ``rmse_shifted`` (float): Root mean square error between shifted forecast and observed blobs.


Raises:
Expand All @@ -663,8 +754,10 @@ def cra(
>>> import xarray as xr
>>> fcst = xr.DataArray(...) # forecast with time dimension
>>> obs = xr.DataArray(...) # observation with time dimension
>>> result = cra(fcst, obs, threshold=5.0, reduce_dims="time")
>>> print(result["mse_total"])
>>> ds = cra(fcst, obs, threshold=5.0,
... y_name="projection_y_coordinate", x_name="projection_x_coordinate",
... extra_components=True)
>>> ds["mse_total"]
"""

# --- Input validation ---
Expand Down Expand Up @@ -738,90 +831,3 @@ def validate_cra2d_inputs(fcst, obs, time_name, coord_units, x_name, y_name):
allowed_units = ["degrees", "metres"]
if coord_units not in allowed_units:
raise ValueError(f"coord_units must be one of {allowed_units}")


# TODO: Merge the docs with cra_image and delete this method
def def_core_2d(
fcst: xr.DataArray,
obs: xr.DataArray,
threshold: float,
y_name: str,
x_name: str,
max_distance: float = 300,
min_points: int = 10,
coord_units: str = "metres",
time_name: Optional[str] = None, # Specify a length-of-one time dimension name
) -> Optional[dict]:
"""
Compute the core Contiguous Rain Area (CRA) decomposition between forecast and observation fields.

This function returns only the essential CRA decomposition components.
To obtain all CRA metrics and diagnostics, use :py:func:`scores.spatial.cra_image`. For time-dependent data,
use the :py:func:`scores.spatial.cra` function.

The core CRA score decomposes the total mean squared error (MSE) into three components:
displacement, volume, and pattern. It identifies contiguous rain blobs above a threshold,
shifts the forecast blob to best match the observed blob, and evaluates the error reduction.

See :py:func:`scores.spatial.cra_image` for more details.

Args:
fcst (xr.DataArray): Forecast field as an xarray DataArray.
obs (xr.DataArray): Observation field as an xarray DataArray.
threshold (float): Threshold to define contiguous rain areas.
y_name (str): Name of the meridional spatial dimension
(e.g., 'lat', 'projection_y_coordinate').
x_name (str): Name of the zonal spatial dimension
(e.g., 'lon', 'projection_x_coordinate').
time_name (Optional[str]): Name of the dimension to use for time-series (e.g. 'time', 'lead_time' or 'valid_time')
max_distance (float): Maximum allowed translation distance in kilometres.
min_points (int): Minimum number of grid points required in a blob.
coord_units (str) : Coordinate units, 'degrees' or 'metres'

Returns:
A CRAMetric instance containing the core CRA components
- mse_total (float): Total mean squared error between forecast and observed blobs.
- mse_displacement (float): MSE reduction due to spatial displacement after optimal alignment.
- mse_volume (float): MSE due to mean intensity (volume) differences.
- mse_pattern (float): Residual MSE after displacement and volume adjustment.
- optimal_shift (list[int]): Optimal [x, y] shift applied to the forecast blob (grid-point units).


Example:
>>> import xarray as xr
>>> from scores.spatial import cra_image
>>> fcst = xr.DataArray(...) # 2D forecast
>>> obs = xr.DataArray(...) # 2D observation
>>> result = cra_image(fcst, obs, threshold=5.0, y_name="lat", x_name="lon")
>>> print(result.mse_total)


"""

# Throw an exception if invalid input
validate_cra2d_inputs(fcst, obs, time_name, coord_units, x_name, y_name)

fcst_blob, obs_blob = _generate_largest_rain_area_2d(fcst, obs, threshold, min_points)
mse_total = float(mse(fcst_blob, obs_blob))

shifted_fcst, dx, dy = _translate_forecast_region(fcst_blob, obs_blob, y_name, x_name, max_distance, coord_units)
assert shifted_fcst is not None
assert obs_blob is not None
mse_shift = mse(shifted_fcst, obs_blob)
mse_volume = _calc_mse_volume(shifted_fcst, obs_blob)

data_vars = {
"mse_total": mse_total,
"mse_shift": mse_shift,
"mse_displacement": mse_total - mse_shift,
"mse_volume": mse_volume,
"mse_pattern": mse_shift - mse_volume,
}

coords = [x_name, y_name]
if time_name:
coords = [time, x_name, y_name]

ds = xr.Dataset(coords={name: obs[name] for name in coords}, data_vars=data_vars)

return ds