# Patch summary (reviewer preamble; sits before the first `diff --git` header).
#
# Purpose: expose the OLS trendline fit results of a figure through a new
# public helper, `get_trendline_results(fig)`.
#
# plotly_express/__init__.py
#   * The `._core` import becomes a parenthesised multi-line import and
#     additionally imports `get_trendline_results`.
#   * "get_trendline_results" is added to the list of exported names.
#
# plotly_express/_core.py
#   * `import math` becomes `import math, pandas`.
#     NOTE(review): PEP 8 asks for one module per import statement; consider
#     splitting this into two lines in a follow-up.
#   * New function `get_trendline_results(fig)` returns `fig._px_trendlines`.
#     NOTE(review): within this patch that attribute is only assigned in
#     `make_figure`; a figure built any other way would presumably raise
#     AttributeError here -- confirm whether that is acceptable.
#   * `make_trace_kwargs` initialises `fit_results = None`, stores the fitted
#     `sm.OLS(...).fit()` object in it on the "ols" branch (the local formerly
#     named `fitted`), and now returns the tuple `(result, fit_results)`
#     instead of `result` alone. Callers must unpack two values.
#   * `make_figure` unpacks that tuple, applies the patch with
#     `trace.update(patch)`, and, whenever `fit_results` is not None, appends a
#     copy of `mapping_labels` with an added "px_fit_results" entry to
#     `trendline_rows`. After the figure is constructed it sets
#     `fig._px_trendlines = pandas.DataFrame(trendline_rows)`.
#
# NOTE(review): in this copy the patch is collapsed onto a single physical
# line, and the `hover_header` string literals are split across physical lines.
# This looks like extraction damage (presumably stripped line-break markup) --
# verify against the upstream commit before attempting to apply the patch.
diff --git a/plotly_express/__init__.py b/plotly_express/__init__.py index bb0e36a..6686d45 100644 --- a/plotly_express/__init__.py +++ b/plotly_express/__init__.py @@ -31,7 +31,12 @@ density_contour, ) -from ._core import ExpressFigure, set_mapbox_access_token, defaults # noqa: F401 +from ._core import ( # noqa: F401 + ExpressFigure, + set_mapbox_access_token, + defaults, + get_trendline_results, +) from . import data, colors # noqa: F401 @@ -61,5 +66,6 @@ "data", "colors", "set_mapbox_access_token", + "get_trendline_results", "ExpressFigure", ] diff --git a/plotly_express/_core.py b/plotly_express/_core.py index 20e05c5..0cfeb31 100644 --- a/plotly_express/_core.py +++ b/plotly_express/_core.py @@ -2,7 +2,7 @@ from plotly.offline import init_notebook_mode, iplot from collections import namedtuple, OrderedDict from .colors import qualitative, sequential -import math +import math, pandas class PxDefaults(object): @@ -55,6 +55,21 @@ def _ipython_display_(self): iplot(self, show_link=False, auto_play=False) +def get_trendline_results(fig): + """ + Extracts fit statistics for trendlines (when applied to figures generated with + the `trendline` argument set to `"ols"`). + + Arguments: + fig: the output of a `plotly_express` charting call + Returns: + A `pandas.DataFrame` with a column "px_fit_results" containing the `statsmodels` + results objects, along with columns identifying the subset of the data the + trendline was fit on. + """ + return fig._px_trendlines + + Mapping = namedtuple( "Mapping", ["show_in_trace_name", "grouper", "val_map", "sequence", "updater", "variable"], @@ -128,6 +143,7 @@ def make_trace_kwargs( if "line_close" in args and args["line_close"]: g = g.append(g.iloc[0]) result = trace_spec.trace_patch.copy() or {} + fit_results = None hover_header = "" for k in trace_spec.attrs: v = args[k] @@ -185,16 +201,18 @@ def make_trace_kwargs( result["y"] = trendline[:, 1] hover_header = "LOWESS trendline

" elif v == "ols": - fitted = sm.OLS(y, sm.add_constant(x)).fit() - result["y"] = fitted.predict() + fit_results = sm.OLS(y, sm.add_constant(x)).fit() + result["y"] = fit_results.predict() hover_header = "OLS trendline
" hover_header += "%s = %f * %s + %f
" % ( args["y"], - fitted.params[1], + fit_results.params[1], args["x"], - fitted.params[0], + fit_results.params[0], + ) + hover_header += ( + "R2=%f

" % fit_results.rsquared ) - hover_header += "R2=%f

" % fitted.rsquared mapping_labels[get_label(args, args["x"])] = "%{x}" mapping_labels[get_label(args, args["y"])] = "%{y} (trend)" @@ -253,7 +271,7 @@ def make_trace_kwargs( if trace_spec.constructor not in [go.Histogram2dContour, go.Parcoords, go.Parcats]: hover_lines = [k + "=" + v for k, v in mapping_labels.items()] result["hovertemplate"] = hover_header + "
".join(hover_lines) - return result + return result, fit_results def configure_axes(args, constructor, fig, axes, orders): @@ -728,6 +746,7 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}): trace_names_by_frame = {} frames = OrderedDict() + trendline_rows = [] for group_name in group_names: group = grouped.get_group(group_name if len(group_name) > 1 else group_name[0]) mapping_labels = OrderedDict() @@ -803,17 +822,19 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}): ): trace.update(marker=dict(color=trace.line.color)) - trace.update( - make_trace_kwargs( - args, - trace_spec, - group, - mapping_labels.copy(), - sizeref, - color_range=color_range, - show_colorbar=(frame_name not in frames), - ) + patch, fit_results = make_trace_kwargs( + args, + trace_spec, + group, + mapping_labels.copy(), + sizeref, + color_range=color_range, + show_colorbar=(frame_name not in frames), ) + trace.update(patch) + if fit_results is not None: + trendline_rows.append(mapping_labels.copy()) + trendline_rows[-1]["px_fit_results"] = fit_results if frame_name not in frames: frames[frame_name] = dict(data=[], name=frame_name) frames[frame_name]["data"].append(trace) @@ -832,6 +853,7 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}): layout=layout_patch, frames=frame_list if len(frames) > 1 else [], ) + fig._px_trendlines = pandas.DataFrame(trendline_rows) axes = {m.variable: m.val_map for m in grouped_mappings} configure_axes(args, constructor, fig, axes, orders) configure_animation_controls(args, constructor, fig)