Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions fmskill/comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -1092,6 +1092,7 @@ def scatter(
*,
bins: Union[int, float, List[int], List[float]] = 20,
quantiles: Union[int, List[float]] = None,
fit_to_quantiles: bool = False,
show_points: Union[bool, int, float] = None,
show_hist: bool = None,
show_density: bool = None,
Expand Down Expand Up @@ -1126,6 +1127,9 @@ def scatter(
number of quantiles for QQ-plot, by default None and will depend on the scatter data length (10, 100 or 1000)
if int, this is the number of points
if sequence (list of floats), represents the desired quantiles (from 0 to 1)
fit_to_quantiles: bool, optional, by default False
by default the regression line is fitted to all data, if True, it is fitted to the quantiles
which can be useful to represent the extremes of the distribution
show_points : (bool, int, float), optional
Should the scatter points be displayed?
None means: show all points if fewer than 1e4, otherwise show 1e4 sample points, by default None.
Expand Down Expand Up @@ -1217,6 +1221,7 @@ def scatter(
y=y,
bins=bins,
quantiles=quantiles,
fit_to_quantiles=fit_to_quantiles,
show_points=show_points,
show_hist=show_hist,
show_density=show_density,
Expand Down Expand Up @@ -2028,6 +2033,7 @@ def scatter(
*,
bins: Union[int, float, List[int], List[float]] = 20,
quantiles: Union[int, List[float]] = None,
fit_to_quantiles: bool = False,
show_points: Union[bool, int, float] = None,
show_hist: bool = None,
show_density: bool = None,
Expand Down Expand Up @@ -2064,6 +2070,9 @@ def scatter(
number of quantiles for QQ-plot, by default None and will depend on the scatter data length (10, 100 or 1000)
if int, this is the number of points
if sequence (list of floats), represents the desired quantiles (from 0 to 1)
fit_to_quantiles: bool, optional, by default False
by default the regression line is fitted to all data, if True, it is fitted to the quantiles
which can be useful to represent the extremes of the distribution
show_points : (bool, int, float), optional
Should the scatter points be displayed?
None means: show all points if fewer than 1e4, otherwise show 1e4 sample points, by default None.
Expand Down Expand Up @@ -2170,6 +2179,7 @@ def scatter(
y=y,
bins=bins,
quantiles=quantiles,
fit_to_quantiles=fit_to_quantiles,
show_points=show_points,
show_hist=show_hist,
show_density=show_density,
Expand Down
28 changes: 20 additions & 8 deletions fmskill/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def scatter(
*,
bins: Union[int, float, List[int], List[float]] = 20,
quantiles: Union[int, List[float]] = None,
fit_to_quantiles: bool = False,
show_points: Union[bool, int, float] = None,
show_hist: bool = None,
show_density: bool = None,
Expand Down Expand Up @@ -98,6 +99,9 @@ def scatter(
number of quantiles for QQ-plot, by default None and will depend on the scatter data length (10, 100 or 1000)
if int, this is the number of points
if sequence (list of floats), represents the desired quantiles (from 0 to 1)
fit_to_quantiles: bool, optional, by default False
by default the regression line is fitted to all data, if True, it is fitted to the quantiles
which can be useful to represent the extremes of the distribution
show_points : (bool, int, float), optional
Should the scatter points be displayed?
None means: show all points if fewer than 1e4, otherwise show 1e4 sample points, by default None.
Expand Down Expand Up @@ -264,7 +268,10 @@ def scatter(
z = z * len(x) / len(x_sample)

# linear fit
slope, intercept = _linear_regression(obs=x, model=y, reg_method=reg_method)
if fit_to_quantiles:
slope, intercept = _linear_regression(obs=xq, model=yq, reg_method=reg_method)
else:
slope, intercept = _linear_regression(obs=x, model=y, reg_method=reg_method)

if intercept < 0:
sign = ""
Expand Down Expand Up @@ -310,13 +317,16 @@ def scatter(
markersize=options.plot.scatter.quantiles.markersize,
**settings.get_option("plot.scatter.quantiles.kwargs"),
)

x_trend = xq if fit_to_quantiles else x_trend
plt.plot(
x_trend,
intercept + slope * x_trend,
**settings.get_option("plot.scatter.reg_line.kwargs"),
label=reglabel,
zorder=2,
)

if show_hist:
plt.hist2d(x, y, bins=nbins_hist, cmin=0.01, zorder=0.5, **kwargs)

Expand Down Expand Up @@ -345,18 +355,20 @@ def scatter(
import plotly.graph_objects as go

data = [
go.Scatter(
x=x,
y=intercept + slope * x,
name=reglabel,
mode="lines",
line=dict(color="red"),
),
go.Scatter(
x=xlim, y=xlim, name="1:1", mode="lines", line=dict(color="blue")
),
]

regression_line = go.Scatter(
x=x_trend,
y=intercept + slope * x_trend,
name=reglabel,
mode="lines",
line=dict(color="red"),
)
data.append(regression_line)

if show_hist:
data.append(
go.Histogram2d(
Expand Down
Loading