diff --git a/plotly_express/__init__.py b/plotly_express/__init__.py
index bb0e36a..6686d45 100644
--- a/plotly_express/__init__.py
+++ b/plotly_express/__init__.py
@@ -31,7 +31,12 @@
density_contour,
)
-from ._core import ExpressFigure, set_mapbox_access_token, defaults # noqa: F401
+from ._core import ( # noqa: F401
+ ExpressFigure,
+ set_mapbox_access_token,
+ defaults,
+ get_trendline_results,
+)
from . import data, colors # noqa: F401
@@ -61,5 +66,6 @@
"data",
"colors",
"set_mapbox_access_token",
+ "get_trendline_results",
"ExpressFigure",
]
diff --git a/plotly_express/_core.py b/plotly_express/_core.py
index 20e05c5..0cfeb31 100644
--- a/plotly_express/_core.py
+++ b/plotly_express/_core.py
@@ -2,7 +2,7 @@
from plotly.offline import init_notebook_mode, iplot
from collections import namedtuple, OrderedDict
from .colors import qualitative, sequential
-import math
+import math, pandas
class PxDefaults(object):
@@ -55,6 +55,21 @@ def _ipython_display_(self):
iplot(self, show_link=False, auto_play=False)
+def get_trendline_results(fig):
+ """
+ Extracts fit statistics for trendlines (when applied to figures generated with
+ the `trendline` argument set to `"ols"`).
+
+ This simply hands back the `_px_trendlines` attribute that `make_figure`
+ attaches to the figure it builds; nothing is recomputed here.
+
+ Arguments:
+ fig: the output of a `plotly_express` charting call
+ Returns:
+ A `pandas.DataFrame` with one row per trace for which an OLS fit was computed.
+ The column "px_fit_results" contains the `statsmodels` results objects
+ (the value returned by `sm.OLS(...).fit()`), along with columns identifying
+ the subset of the data the trendline was fit on (copied from the trace's
+ mapping labels). When no OLS trendline was fit (e.g. no trendline, or a
+ LOWESS trendline, which stores no fit results), the frame is built from an
+ empty list of rows and is therefore empty.
+ Raises:
+ AttributeError: if `fig` carries no `_px_trendlines` attribute
+ (NOTE(review): presumably any figure not produced by `make_figure`).
+ """
+ # Same object that is stored on the figure (no copy is made).
+ return fig._px_trendlines
+
+
Mapping = namedtuple(
"Mapping",
["show_in_trace_name", "grouper", "val_map", "sequence", "updater", "variable"],
@@ -128,6 +143,7 @@ def make_trace_kwargs(
if "line_close" in args and args["line_close"]:
g = g.append(g.iloc[0])
result = trace_spec.trace_patch.copy() or {}
+ fit_results = None
hover_header = ""
for k in trace_spec.attrs:
v = args[k]
@@ -185,16 +201,18 @@ def make_trace_kwargs(
result["y"] = trendline[:, 1]
hover_header = "LOWESS trendline
"
elif v == "ols":
- fitted = sm.OLS(y, sm.add_constant(x)).fit()
- result["y"] = fitted.predict()
+ fit_results = sm.OLS(y, sm.add_constant(x)).fit()
+ result["y"] = fit_results.predict()
hover_header = "OLS trendline
"
hover_header += "%s = %f * %s + %f
" % (
args["y"],
- fitted.params[1],
+ fit_results.params[1],
args["x"],
- fitted.params[0],
+ fit_results.params[0],
+ )
+ hover_header += (
+ "R2=%f
" % fit_results.rsquared
)
- hover_header += "R2=%f
" % fitted.rsquared
mapping_labels[get_label(args, args["x"])] = "%{x}"
mapping_labels[get_label(args, args["y"])] = "%{y} (trend)"
@@ -253,7 +271,7 @@ def make_trace_kwargs(
if trace_spec.constructor not in [go.Histogram2dContour, go.Parcoords, go.Parcats]:
hover_lines = [k + "=" + v for k, v in mapping_labels.items()]
result["hovertemplate"] = hover_header + "
".join(hover_lines)
- return result
+ return result, fit_results
def configure_axes(args, constructor, fig, axes, orders):
@@ -728,6 +746,7 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
trace_names_by_frame = {}
frames = OrderedDict()
+ trendline_rows = []
for group_name in group_names:
group = grouped.get_group(group_name if len(group_name) > 1 else group_name[0])
mapping_labels = OrderedDict()
@@ -803,17 +822,19 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
):
trace.update(marker=dict(color=trace.line.color))
- trace.update(
- make_trace_kwargs(
- args,
- trace_spec,
- group,
- mapping_labels.copy(),
- sizeref,
- color_range=color_range,
- show_colorbar=(frame_name not in frames),
- )
+ patch, fit_results = make_trace_kwargs(
+ args,
+ trace_spec,
+ group,
+ mapping_labels.copy(),
+ sizeref,
+ color_range=color_range,
+ show_colorbar=(frame_name not in frames),
)
+ trace.update(patch)
+ if fit_results is not None:
+ trendline_rows.append(mapping_labels.copy())
+ trendline_rows[-1]["px_fit_results"] = fit_results
if frame_name not in frames:
frames[frame_name] = dict(data=[], name=frame_name)
frames[frame_name]["data"].append(trace)
@@ -832,6 +853,7 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
layout=layout_patch,
frames=frame_list if len(frames) > 1 else [],
)
+ fig._px_trendlines = pandas.DataFrame(trendline_rows)
axes = {m.variable: m.val_map for m in grouped_mappings}
configure_axes(args, constructor, fig, axes, orders)
configure_animation_controls(args, constructor, fig)