Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion modelskill/comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -757,7 +757,7 @@ def __add__(
cmp.data = cmp.data.isel(time=index)

else:
cols = ["x", "y"] if isinstance(self, TrackComparer) else []
cols = ["x", "y"]
mod_data = [self.data[cols + [m]] for m in self.mod_names]
for m in other.mod_names:
mod_data.append(other.data[cols + [m]])
Expand Down
14 changes: 13 additions & 1 deletion tests/test_comparer.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,18 @@ def test_pc_query_empty(pc):
assert pc2.n_points == 0


def test_add_pc_tc(pc, tc):
cc = pc + tc
assert cc.n_points == 10
assert cc.n_comparers == 2


def test_add_tc_pc(pc, tc):
cc = tc + pc
assert cc.n_points == 10
assert cc.n_comparers == 2


def test_pc_to_dataframe(pc):
df = pc.to_dataframe()
assert isinstance(df, pd.DataFrame)
Expand Down Expand Up @@ -372,4 +384,4 @@ def test_pc_to_dataframe_add_col(pc):
assert isinstance(df, pd.DataFrame)
assert df.shape == (10, 7)
assert "derived" in df.columns
assert df.derived.dtype == "float64"
assert df.derived.dtype == "float64"
30 changes: 30 additions & 0 deletions tests/test_comparercollection.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,3 +170,33 @@ def test_hist(cc):
def test_kde(cc):
ax = cc.kde()
assert ax is not None



def test_add_cc_pc(cc, pc):
pc2 = pc.copy()
pc2.data.attrs["name"] = "pc2"
cc2 = cc + pc2
assert cc2.n_points == 15
assert cc2.n_comparers == 3


def test_add_cc_tc(cc, tc):
tc2 = tc.copy()
tc2.data.attrs["name"] = "tc2"
cc2 = cc + tc2
assert cc2.n_points == 15
assert cc2.n_comparers == 3


def test_add_cc_cc(cc, pc, tc):
pc2 = pc.copy()
pc2.data.attrs["name"] = "pc2"
tc2 = tc.copy()
tc2.data.attrs["name"] = "tc2"
tc3 = tc.copy() # keep name
cc2 = pc2 + tc2 + tc3

cc3 = cc + cc2
#assert cc3.n_points == 15
assert cc3.n_comparers == 4