Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions modelskill/comparison/_collection_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class ComparerCollectionPlotter:
>>> cc.plot.hist()
>>> cc.plot.kde()
>>> cc.plot.taylor()
>>> cc.plot.box()
"""

def __init__(self, cc: ComparerCollection) -> None:
Expand Down Expand Up @@ -535,3 +536,61 @@ def taylor(
normalize_std=normalize_std,
title=title,
)

def box(self, *, ax=None, figsize=None, title=None, **kwargs) -> Axes:
"""Plot box plot of observations and model data.

Parameters
----------
ax : Axes, optional
matplotlib axes, by default None
figsize : tuple, optional
width and height of the figure, by default None
title : str, optional
plot title, by default None
**kwargs
passed to pandas.DataFrame.plot.box()

Returns
-------
Axes
matplotlib axes

Examples
--------
>>> cc.plot.box()
>>> cc.plot.box(showmeans=True)
>>> cc.plot.box(ax=ax, title="Box plot")
"""
_, ax = _get_fig_ax(ax, figsize)

df = self.cc._to_long_dataframe()

unique_obs_cols = ["time", "x", "y", "observation"]
df = df.set_index(unique_obs_cols)
unique_obs_values = df[~df.duplicated()].obs_val.values

data = {"Observation": unique_obs_values}
for model in df.model.unique():
df_model = df[df.model == model]
data[model] = df_model.mod_val.values

data = {k: pd.Series(v) for k, v in data.items()}
df = pd.DataFrame(data)

if "grid" not in kwargs:
kwargs["grid"] = True

ax = df.plot.box(ax=ax, **kwargs)

ax.set_ylabel(f"{self.cc._unit_text}")

title = (
_default_univarate_title("Box plot", self.cc) if title is None else title
)
ax.set_title(title)

if self.is_directional:
_ytick_directional(ax)

return ax
1 change: 1 addition & 0 deletions tests/test_comparercollection.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,7 @@ def test_plots_directional(cc):
"kde",
"hist",
"taylor",
"box",
]
)
def cc_plot_function(cc, request):
Expand Down