diff --git a/modelskill/matching.py b/modelskill/matching.py
index 5d3a2c000..821c1c628 100644
--- a/modelskill/matching.py
+++ b/modelskill/matching.py
@@ -405,6 +405,16 @@ def match_space_time(
                 observation=observation, mri=mri, spatial_tolerance=spatial_tolerance
             )
 
+        # variables of this model must not clash with what is already in `data`
+        # (the observation plus any model results merged in before this one)
+        if overlapping_names := set(mri.data.data_vars).intersection(
+            set(data.data_vars)
+        ):
+            raise ValueError(
+                f"Model: '{mr.name}' has variables overlapping with the observation "
+                f"or a previously matched model: {sorted(map(str, overlapping_names))}"
+            )
+
         # TODO: is name needed?
         for v in list(mri.data.data_vars):
             data[v] = mri.data[v]
diff --git a/tests/test_match.py b/tests/test_match.py
index 609d9c828..250b564c8 100644
--- a/tests/test_match.py
+++ b/tests/test_match.py
@@ -389,3 +389,56 @@ def test_specifying_mod_item_not_allowed_twice(o1, mr1):
 def test_bad_model_input(o1):
     with pytest.raises(ValueError, match="mod type"):
         ms.match(obs=o1, mod=None)
+
+
+def test_obs_and_mod_can_not_have_same_aux_item_names():
+    """An aux item name shared by the observation and the model is rejected."""
+    obs_df = pd.DataFrame(
+        {"wl": [1.0, 2.0, 3.0], "wind_speed": [1.0, 2.0, 3.0]},
+        index=pd.date_range("2017-01-01", periods=3),
+    )
+
+    mod_df = pd.DataFrame(
+        {"wl": [1.1, 2.0, 3.0], "wind_speed": [0.0, 0.0, 0.0]},
+        index=pd.date_range("2017-01-01", periods=3),
+    )
+
+    obs = ms.PointObservation(obs_df, item="wl", aux_items=["wind_speed"])
+    mod = ms.PointModelResult(mod_df, item="wl", aux_items=["wind_speed"])
+
+    with pytest.raises(ValueError, match="wind_speed"):
+        ms.match(obs=obs, mod=mod)
+
+
+def test_mod_aux_items_must_be_unique():
+    """Two models carrying the same aux item name can not be matched together."""
+    obs_df = pd.DataFrame(
+        {"wl": [1.0, 2.0, 3.0], "wind_speed": [1.0, 2.0, 3.0]},
+        index=pd.date_range("2017-01-01", periods=3),
+    )
+
+    mod_df = pd.DataFrame(
+        {"wl": [1.1, 2.0, 3.0], "wind_speed": [0.0, 0.0, 0.0]},
+        index=pd.date_range("2017-01-01", periods=3),
+    )
+
+    mod2_df = pd.DataFrame(
+        {"wl": [1.2, 2.1, 3.1], "wind_speed": [0.0, 0.0, 0.0]},
+        index=pd.date_range("2017-01-01", periods=3),
+    )
+
+    obs = ms.PointObservation(obs_df, item="wl")
+    mod = ms.PointModelResult(mod_df, item="wl", aux_items=["wind_speed"], name="local")
+
+    # creating a second model with the same aux item is fine on its own;
+    # the name clash only surfaces when both models are matched together
+    mod2 = ms.PointModelResult(
+        mod2_df, item="wl", aux_items=["wind_speed"], name="remote"
+    )
+
+    with pytest.raises(ValueError) as excinfo:
+        ms.match(obs=obs, mod=[mod, mod2])
+
+    # the message must identify both the clashing variable and the offending model
+    assert "wind_speed" in str(excinfo.value)
+    assert "remote" in str(excinfo.value)