Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 1 addition & 6 deletions src/modelskill/comparison/_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -759,17 +759,12 @@ def merge(
Comparer or ComparerCollection
New object with merged data.
"""
from ..matching import match_space_time
from ._collection import ComparerCollection

if isinstance(other, Comparer) and (self.name == other.name):
raw_mod_data = self.raw_mod_data.copy()
raw_mod_data.update(other.raw_mod_data) # TODO!
matched = match_space_time(
observation=self._to_observation(),
raw_mod_data=raw_mod_data, # type: ignore
)
assert matched is not None
matched = self.data.merge(other.data).dropna(dim="time")
cmp = Comparer(matched_data=matched, raw_mod_data=raw_mod_data)

return cmp
Expand Down
74 changes: 35 additions & 39 deletions src/modelskill/matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@

from . import Quantity, __version__, model_result
from .comparison import Comparer, ComparerCollection
from .model._base import Alignable
from .model.dfsu import DfsuModelResult
from .model.dummy import DummyModelResult
from .model.grid import GridModelResult
Expand Down Expand Up @@ -199,6 +198,7 @@ def match(
gtype=None,
max_model_gap=None,
spatial_method: Optional[str] = None,
spatial_tolerance: float = 1e-3,
obs_no_overlap: Literal["ignore", "error", "warn"] = "error",
):
"""Match observation and model result data in space and time
Expand Down Expand Up @@ -233,6 +233,11 @@ def match(
'inverse_distance' (with 5 nearest points), by default "inverse_distance".
- For GridModelResult, passed to xarray.interp() as method argument,
by default 'linear'.
spatial_tolerance : float, optional
Spatial tolerance (in the units of the coordinate system) for matching
model track points to observation track points. Model points outside
this tolerance will be discarded. Only relevant for TrackModelResult
and TrackObservation, by default 1e-3.
obs_no_overlap: str, optional
How to handle observations with no overlap with model results. One of: 'ignore', 'error', 'warn', by default 'error'.

Expand All @@ -256,6 +261,7 @@ def match(
gtype=gtype,
max_model_gap=max_model_gap,
spatial_method=spatial_method,
spatial_tolerance=spatial_tolerance,
obs_no_overlap=obs_no_overlap,
)

Expand Down Expand Up @@ -292,6 +298,7 @@ def match(
gtype=gtype,
max_model_gap=max_model_gap,
spatial_method=spatial_method,
spatial_tolerance=spatial_tolerance,
obs_no_overlap=obs_no_overlap,
)
for o in obs
Expand All @@ -306,13 +313,15 @@ def _match_single_obs(
obs: ObsInputType,
mod: Union[MRInputType, Sequence[MRInputType]],
*,
obs_item: Optional[int | str] = None,
mod_item: Optional[int | str] = None,
gtype: Optional[GeometryTypes] = None,
max_model_gap: Optional[float] = None,
spatial_method: Optional[str] = None,
obs_no_overlap: Literal["ignore", "error", "warn"] = "error",
) -> Optional[Comparer]:
obs_item: int | str | None,
mod_item: int | str | None,
gtype: GeometryTypes | None,
max_model_gap: float | None,
spatial_method: str | None,
spatial_tolerance: float,
obs_no_overlap: Literal["ignore", "error", "warn"],
) -> Comparer | None:
# TODO passing gtype to this function is inconsistent with `match` docstring, where gtype is the geometry type of model result
observation = _parse_single_obs(obs, obs_item, gtype=gtype)

if isinstance(mod, get_args(MRInputType)):
Expand All @@ -334,11 +343,12 @@ def _match_single_obs(
for m in model_results
}

matched_data = match_space_time(
matched_data = _match_space_time(
observation=observation,
raw_mod_data=raw_mod_data,
max_model_gap=max_model_gap,
obs_no_overlap=obs_no_overlap,
spatial_tolerance=spatial_tolerance,
)
if matched_data is None:
return None
Expand All @@ -359,36 +369,13 @@ def _get_global_start_end(idxs: Iterable[pd.DatetimeIndex]) -> Period:
return Period(start=min(starts), end=max(ends))


def match_space_time(
def _match_space_time(
observation: Observation,
raw_mod_data: Mapping[str, Alignable],
max_model_gap: float | None = None,
obs_no_overlap: Literal["ignore", "error", "warn"] = "error",
raw_mod_data: Mapping[str, PointModelResult | TrackModelResult],
max_model_gap: float | None,
spatial_tolerance: float,
obs_no_overlap: Literal["ignore", "error", "warn"],
) -> Optional[xr.Dataset]:
"""Match observation with one or more model results in time domain.

and return as xr.Dataset in the format used by modelskill.Comparer

Will interpolate model results to observation time.

Note: assumes that observation and model data are already matched in space.
But positions of track observations will be checked.

Parameters
----------
observation : Observation
Observation to be matched
raw_mod_data : Mapping[str, Alignable]
Mapping of model results ready for interpolation
max_model_gap : Optional[TimeDeltaTypes], optional
In case of non-equidistant model results (e.g. event data),
max_model_gap can be given e.g. as seconds, by default None

Returns
-------
xr.Dataset or None
Matched data in the format used by modelskill.Comparer
"""
idxs = [m.time for m in raw_mod_data.values()]
period = _get_global_start_end(idxs)

Expand All @@ -401,8 +388,17 @@ def match_space_time(
data = data.rename({observation.name: "Observation"})

for mr in raw_mod_data.values():
# TODO is `align` the correct name for this operation?
aligned = mr.align(observation, max_gap=max_model_gap)
match mr, observation:
case TrackModelResult() as tmr, TrackObservation():
aligned = tmr.subset_to(
observation, spatial_tolerance=spatial_tolerance
)
case PointModelResult() as pmr, PointObservation():
aligned = pmr.align(observation, max_gap=max_model_gap)
case _:
raise TypeError(
f"Matching not implemented for model type {type(mr)} and observation type {type(observation)}"
)

if overlapping := set(aligned.filter_by_attrs(kind="aux").data_vars) & set(
observation.data.filter_by_attrs(kind="aux").data_vars
Expand Down
16 changes: 1 addition & 15 deletions src/modelskill/model/_base.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
from __future__ import annotations
from collections import Counter
from typing import Any, List, Optional, Protocol, Sequence, TYPE_CHECKING
from typing import List, Optional, Protocol, Sequence, TYPE_CHECKING
from dataclasses import dataclass
import warnings

import pandas as pd

if TYPE_CHECKING:
import xarray as xr
from .point import PointModelResult
from .track import TrackModelResult

Expand Down Expand Up @@ -86,16 +85,3 @@ def _extract_point(
def _extract_track(
self, observation: TrackObservation, spatial_method: Optional[str] = None
) -> TrackModelResult: ...


class Alignable(Protocol):
@property
def time(self) -> pd.DatetimeIndex: ...

def align(
self,
observation: Observation,
**kwargs: Any,
) -> xr.Dataset: ...

# the attributues of the returned dataset have additional requirements, but we can't express that here
3 changes: 1 addition & 2 deletions src/modelskill/model/point.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,9 @@
from ..types import PointType
from ..quantity import Quantity
from ..timeseries import TimeSeries, _parse_point_input
from ._base import Alignable


class PointModelResult(TimeSeries, Alignable):
class PointModelResult(TimeSeries):
"""Model result for a single point location.

Construct a PointModelResult from a 0d data source:
Expand Down
13 changes: 6 additions & 7 deletions src/modelskill/model/track.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
from __future__ import annotations
from typing import Any, Literal, Optional, Sequence
from typing import Literal, Optional, Sequence
import warnings

import numpy as np
import xarray as xr

from ..obs import Observation
from ..types import TrackType
from ..obs import TrackObservation
from ..quantity import Quantity
from ..timeseries import TimeSeries, _parse_track_input
from ._base import Alignable


class TrackModelResult(TimeSeries, Alignable):
class TrackModelResult(TimeSeries):
"""Model result for a track.

Construct a TrackModelResult from a dfs0 file,
Expand Down Expand Up @@ -72,9 +71,9 @@ def __init__(
data[data_var].attrs["kind"] = "model"
super().__init__(data=data)

def align(self, observation: Observation, **kwargs: Any) -> xr.Dataset:
spatial_tolerance = 1e-3

def subset_to(
self, observation: TrackObservation, *, spatial_tolerance: float
) -> xr.Dataset:
mri = self
mod_df = mri.data.to_dataframe()
obs_df = observation.data.to_dataframe()
Expand Down
17 changes: 17 additions & 0 deletions tests/test_simple_compare.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest
import mikeio
from datetime import datetime
import pandas as pd
import modelskill as ms


Expand Down Expand Up @@ -117,3 +118,19 @@ def test_compare_obs_item_pointobs_inconsistent_item_error(fn_mod):
def test_force_keyword_args(fn_obs, fn_mod):
with pytest.raises(TypeError):
ms.match(fn_obs, fn_mod, 0, 0)


def test_matching_pointobservation_with_trackmodelresult_is_not_possible():
# ignore the data
tdf = pd.DataFrame(
{"x": [1, 2], "y": [1, 2], "m1": [0, 0]},
index=pd.date_range("2017-10-27 13:00:01", periods=2, freq="4S"),
)
mr = ms.TrackModelResult(tdf, item="m1", x_item="x", y_item="y")
pdf = pd.DataFrame(
data={"level": [0.0, 0.0]},
index=pd.date_range("2017-10-27 13:00:01", periods=2, freq="4S"),
)
obs = ms.PointObservation(pdf, item="level")
with pytest.raises(TypeError, match="TrackModelResult"):
ms.match(obs=obs, mod=mr)
2 changes: 1 addition & 1 deletion tests/test_trackcompare.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ def test_tiny_mod_xy_difference(obs_tiny_df, mod_tiny_unique):
)
with pytest.warns(UserWarning, match="Removed 2 model points"):
# 2 points removed due to difference in x,y
cmp = ms.match(obs_tiny, mod_tiny_unique)
cmp = ms.match(obs_tiny, mod_tiny_unique, spatial_tolerance=1e-3)
assert cmp.n_points == 2
expected_time = pd.DatetimeIndex(
[
Expand Down