import plotly.graph_objs as go
import plotly.io as pio
from collections import namedtuple, OrderedDict
from ._special_inputs import IdentityMap, Constant, Range
from _plotly_utils.basevalidators import ColorscaleValidator
from plotly.colors import qualitative, sequential
import math
import pandas as pd
import numpy as np
from plotly.subplots import (
make_subplots,
_set_trace_grid_reference,
_subplot_type_for_trace_type,
)
NO_COLOR = "px_no_color_constant"

# Every attribute px can map to data, across all plot types, grouped by how
# process_args_into_dataframe and the grouping machinery treat it.

# Attributes that take a single column reference each
direct_attrables = [
    "base", "x", "y", "z", "a", "b", "c", "r", "theta", "size", "x_start", "x_end",
    "hover_name", "text", "names", "values", "parents", "wide_cross",
    "ids", "error_x", "error_x_minus", "error_y", "error_y_minus", "error_z",
    "error_z_minus", "lat", "lon", "locations", "animation_group",
]
# Attributes that take a list of column references
array_attrables = ["dimensions", "custom_data", "hover_data", "path", "wide_variable"]
# Attributes whose values only split the data into groups
group_attrables = ["animation_frame", "facet_row", "facet_col", "line_group"]
# Grouping attributes that additionally drive a visual property of each group
renameable_group_attrables = [
    "color",  # renamed to marker.color or line.color in infer_config
    "symbol",  # renamed to marker.symbol in infer_config
    "line_dash",  # renamed to line.dash in infer_config
    "pattern_shape",  # renamed to marker.pattern.shape in infer_config
]
all_attrables = (
    direct_attrables + array_attrables + group_attrables + renameable_group_attrables
)
# Trace types drawn on cartesian x/y axes; configure_axes routes all of them
# to configure_cartesian_axes.
cartesians = [
    go.Scatter,
    go.Scattergl,
    go.Bar,
    go.Funnel,
    go.Box,
    go.Violin,
    go.Histogram,
    go.Histogram2d,
    go.Histogram2dContour,
]
class PxDefaults(object):
    """Package-wide default values for plotly express arguments.

    A single instance is exposed as ``defaults``; apply_default_cascade copies
    each slot into any argument of the same name that a px call left as None.
    """

    __slots__ = [
        "template",
        "width",
        "height",
        "color_discrete_sequence",
        "color_discrete_map",
        "color_continuous_scale",
        "symbol_sequence",
        "symbol_map",
        "line_dash_sequence",
        "line_dash_map",
        "pattern_shape_sequence",
        "pattern_shape_map",
        "size_max",
        "category_orders",
        "labels",
    ]

    def __init__(self):
        self.reset()

    def reset(self):
        """Restore every setting to its built-in value.

        Everything becomes None except ``size_max`` (20) and the mapping-type
        settings, which each get a fresh empty dict.
        """
        for name in self.__slots__:
            setattr(self, name, None)
        mapping_slots = (
            "color_discrete_map",
            "symbol_map",
            "line_dash_map",
            "pattern_shape_map",
            "category_orders",
            "labels",
        )
        for name in mapping_slots:
            setattr(self, name, {})
        self.size_max = 20


defaults = PxDefaults()
del PxDefaults
# Module-wide Mapbox access token; configure_mapbox passes it to every mapbox
# subplot it configures.
MAPBOX_TOKEN = None


def set_mapbox_access_token(token):
    """
    Arguments:
        token: A Mapbox token to be used in `plotly.express.scatter_mapbox` and
        `plotly.express.line_mapbox` figures. See
        https://docs.mapbox.com/help/how-mapbox-works/access-tokens/ for more details
    """
    # rebind the module-level name so later figure builds pick it up
    global MAPBOX_TOKEN
    MAPBOX_TOKEN = token
def get_trendline_results(fig):
    """
    Extracts fit statistics for trendlines (when applied to figures generated with
    the `trendline` argument set to `"ols"`).

    Arguments:
        fig: the output of a `plotly.express` charting call

    Returns:
        A `pandas.DataFrame` with a column "px_fit_results" containing the `statsmodels`
        results objects, along with columns identifying the subset of the data the
        trendline was fit on.
    """
    # px stashes the table on the figure under a private attribute
    trendline_results = fig._px_trendlines
    return trendline_results
# How the values of one grouping column become trace properties (built by
# make_mapping): which column groups the data, the value lookup/sequence used
# to pick a property value per group, and the callback applying it to a trace.
Mapping = namedtuple(
    "Mapping",
    "show_in_trace_name grouper val_map sequence updater variable facet",
)
# One trace to create: its constructor, the attrables it consumes, a fixed
# property patch, and which marginal (if any) it draws (built by make_trace_spec).
TraceSpec = namedtuple("TraceSpec", "constructor attrs trace_patch marginal")
def get_label(args, column):
    """Return the display label for *column* from ``args["labels"]``.

    Falls back to *column* itself on any lookup failure: missing "labels"
    key, column absent from the mapping, or an unhashable column.
    """
    try:
        label = args["labels"][column]
    except Exception:
        label = column
    return label
def invert_label(args, column):
    """Invert mapping.

    Find the key whose value in ``args["labels"]`` is *column*; when several
    keys share that value the last one wins. Returns *column* unchanged when
    no key maps to it (or when *column* cannot be looked up at all).
    """
    labels = args["labels"]
    label_to_key = dict(zip(labels.values(), labels.keys()))
    try:
        return label_to_key[column]
    except Exception:
        return column
def _is_continuous(df, col_name):
return df[col_name].dtype.kind in "ifc"
def get_decorated_label(args, column, role):
    """Return the label for *column*, decorated with histogram aggregation info.

    The plain label (see get_label) is returned unless *args* has a
    "histfunc" and *role* is the aggregated dimension: "z", "x" for
    horizontal orientation, or "y" for vertical orientation. In that case
    the label describes the aggregate, e.g. "sum of price", "count",
    "probability", "fraction of sum of price". A "barnorm" suffix is then
    appended when set.
    """
    original_label = label = get_label(args, column)

    aggregated_role = (
        role == "z"
        or (role == "x" and args.get("orientation") == "h")
        or (role == "y" and args.get("orientation") == "v")
    )
    if "histfunc" not in args or not aggregated_role:
        return label

    histfunc = args["histfunc"] or "count"
    if histfunc == "count":
        label = "count"
    else:
        label = "%s of %s" % (histfunc, label)

    histnorm = args.get("histnorm")
    if histnorm is not None:
        if label == "count":
            label = histnorm
        elif histfunc == "sum":
            if histnorm == "probability":
                label = "%s of %s" % ("fraction", label)
            elif histnorm == "percent":
                label = "%s of %s" % (histnorm, label)
            else:
                # density-style norms of a sum read as weighted densities
                label = "%s weighted by %s" % (histnorm, original_label)
        elif histnorm == "probability":
            label = "%s of sum of %s" % ("fraction", label)
        elif histnorm == "percent":
            label = "%s of sum of %s" % ("percent", label)
        else:
            label = "%s of %s" % (histnorm, label)

    barnorm = args.get("barnorm")
    if barnorm is not None:
        label = "%s (normalized as %s)" % (label, barnorm)
    return label
def make_mapping(args, variable):
    """Build the Mapping describing how one grouping variable is applied.

    *variable* is one of:

    - "line_group" or "animation_frame": the data are only split into
      groups, and nothing is drawn from the value;
    - "facet_row" or "facet_col": each group is assigned a facet index
      from 1 to 999;
    - a dotted trace-property path such as "marker.color". The second
      component ("color", "symbol", "dash", "pattern") selects the
      ``<prefix>_map`` / ``<prefix>_sequence`` arguments. A map equal to
      "identity" uses the group values themselves.
    """
    if variable in ("line_group", "animation_frame"):
        return Mapping(
            show_in_trace_name=False,
            grouper=args[variable],
            val_map={},
            sequence=[""],
            variable=variable,
            updater=(lambda trace, v: v),
            facet=None,
        )

    if variable in ("facet_row", "facet_col"):
        is_row = variable == "facet_row"
        return Mapping(
            show_in_trace_name=False,
            variable="y" if is_row else "x",
            grouper=args[variable],
            val_map={},
            sequence=list(range(1, 1000)),
            updater=(lambda trace, v: v),
            facet="row" if is_row else "col",
        )

    parent, variable, *other_variables = variable.split(".")
    # (prefix of the *_map / *_sequence args, name of the grouping arg)
    prefix_and_arg = {
        "color": ("color_discrete", "color"),
        "dash": ("line_dash", "line_dash"),
        "pattern": ("pattern_shape", "pattern_shape"),
    }
    vprefix, arg_name = prefix_and_arg.get(variable, (variable, variable))

    map_arg = args[vprefix + "_map"]
    val_map = IdentityMap() if map_arg == "identity" else map_arg.copy()

    nested_key = ".".join([variable] + other_variables)
    return Mapping(
        show_in_trace_name=True,
        variable=variable,
        grouper=args[arg_name],
        val_map=val_map,
        sequence=args[vprefix + "_sequence"],
        updater=lambda trace, v: trace.update({parent: {nested_key: v}}),
        facet=None,
    )
def _fit_trendline(args, trace_data, method):
    """Fit a ``method`` ("ols" or "lowess") trendline of args["y"] on args["x"].

    Returns ``(x, y, fit_results, hover_header)``:

    - ``x`` holds the original (possibly datetime) x values of the rows
      where both x and y are non-NaN, sorted by x;
    - ``y`` holds the fitted values;
    - ``fit_results`` is the statsmodels OLS results (None for lowess);
    - ``hover_header`` is the HTML header for the hovertemplate.

    Raises ValueError if an object-dtype x or y column cannot be converted
    to float.
    """
    import statsmodels.api as sm

    # sorting is bad but trace_specs with "trendline" have no other attrs
    sorted_trace_data = trace_data.sort_values(by=args["x"])
    y = sorted_trace_data[args["y"]].values
    x = sorted_trace_data[args["x"]].values

    if x.dtype.type == np.datetime64:
        # Explicit int64: the platform ``int`` can be 32 bits wide (e.g. on
        # Windows with NumPy < 2), which overflows nanosecond timestamps.
        x = x.astype(np.int64) / 10 ** 9  # convert to unix epoch seconds
    elif x.dtype.type == np.object_:
        try:
            x = x.astype(np.float64)
        except ValueError:
            raise ValueError(
                "Could not convert value of 'x' ('%s') into a numeric type. "
                "If 'x' contains stringified dates, please convert to a datetime column."
                % args["x"]
            )
    if y.dtype.type == np.object_:
        try:
            y = y.astype(np.float64)
        except ValueError:
            raise ValueError("Could not convert value of 'y' into a numeric type.")

    # preserve original values of "x" in case they're dates
    trendline_x = sorted_trace_data[args["x"]][
        np.logical_not(np.logical_or(np.isnan(y), np.isnan(x)))
    ]

    fit_results = None
    if method == "lowess":
        # missing ='drop' is the default value for lowess but not for OLS (None)
        # we force it here in case statsmodels change their defaults
        trendline = sm.nonparametric.lowess(y, x, missing="drop")
        trendline_y = trendline[:, 1]
        hover_header = "<b>LOWESS trendline</b><br><br>"
    else:  # method == "ols"
        fit_results = sm.OLS(y, sm.add_constant(x), missing="drop").fit()
        trendline_y = fit_results.predict()
        hover_header = "<b>OLS trendline</b><br>"
        if len(fit_results.params) == 2:
            hover_header += "%s = %g * %s + %g<br>" % (
                args["y"],
                fit_results.params[1],
                args["x"],
                fit_results.params[0],
            )
        else:
            # presumably only an intercept was fitted (e.g. add_constant did
            # not add a column for constant x) -- TODO confirm
            hover_header += "%s = %g<br>" % (
                args["y"],
                fit_results.params[0],
            )
        hover_header += "R<sup>2</sup>=%f<br><br>" % fit_results.rsquared
    return trendline_x, trendline_y, fit_results, hover_header


def _build_hovertemplate(args, hover_header, mapping_labels):
    """Assemble a hovertemplate from *hover_header* and ``label=value`` lines.

    When ``args["hover_data"]`` is a dict of ``(formatter, data)`` tuples, a
    string formatter is inserted into the matching ``%{...}`` placeholder and
    a falsy formatter hides that line. *mapping_labels* is not modified.
    """
    mapping_labels_copy = OrderedDict(mapping_labels)
    hover_data = args.get("hover_data")
    if hover_data and isinstance(hover_data, dict):
        for k, v in mapping_labels.items():
            # mapping_labels is keyed by display label, hover_data by column
            k_args = invert_label(args, k)
            if k_args in hover_data:
                formatter = hover_data[k_args][0]
                if formatter:
                    if isinstance(formatter, str):
                        mapping_labels_copy[k] = v.replace("}", "%s}" % formatter)
                else:
                    mapping_labels_copy.pop(k)
    hover_lines = [k + "=" + v for k, v in mapping_labels_copy.items()]
    # an empty <extra></extra> tag hides the secondary box with the trace name
    return hover_header + "<br>".join(hover_lines) + "<extra></extra>"


def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref):
    """Populates a dict with arguments to update trace

    Parameters
    ----------
    args : dict
        args to be used for the trace
    trace_spec : NamedTuple
        which kind of trace to be used (has constructor, marginal etc.
        attributes)
    trace_data : pandas DataFrame
        data
    mapping_labels : dict
        label -> hovertemplate placeholder; updated in place with one entry
        per attribute shown on hover
    sizeref : float
        marker sizeref

    Returns
    -------
    trace_patch : dict
        dict to be used to update trace
    fit_results : statsmodels results or None
        OLS fit results when this is an "ols" trendline trace, else None
    """
    if args.get("line_close") and len(trace_data) > 0:
        # Repeat the first row at the end to close the line. pd.concat replaces
        # DataFrame.append (removed in pandas 2.0) and keeps column dtypes.
        trace_data = pd.concat([trace_data, trace_data.iloc[[0]]])

    # Copy nested dicts too: "marker", "line", ... are filled in below and must
    # not leak back into the trace_spec shared by every trace of this kind.
    trace_patch = {
        key: dict(value) if isinstance(value, dict) else value
        for key, value in (trace_spec.trace_patch or {}).items()
    }
    fit_results = None
    hover_header = ""

    # trace types that aggregate, so per-point hover data make no sense
    histogram_types = (go.Histogram, go.Histogram2d, go.Histogram2dContour)
    # trace types colored per sector via marker.colors
    sector_types = (go.Sunburst, go.Treemap, go.Icicle, go.Pie, go.Funnelarea)

    for attr_name in trace_spec.attrs:
        attr_value = args[attr_name]
        attr_label = get_decorated_label(args, attr_value, attr_name)
        if attr_name == "dimensions":
            dims = [
                (name, column)
                for (name, column) in trace_data.items()
                if ((not attr_value) or (name in attr_value))
                and (
                    trace_spec.constructor != go.Parcoords
                    or _is_continuous(args["data_frame"], name)
                )
                and (
                    trace_spec.constructor != go.Parcats
                    or (attr_value is not None and name in attr_value)
                    or len(args["data_frame"][name].unique())
                    <= args["dimensions_max_cardinality"]
                )
            ]
            trace_patch["dimensions"] = [
                dict(label=get_label(args, name), values=column)
                for (name, column) in dims
            ]
            if trace_spec.constructor == go.Splom:
                for d in trace_patch["dimensions"]:
                    d["axis"] = dict(matches=True)
                mapping_labels["%{xaxis.title.text}"] = "%{x}"
                mapping_labels["%{yaxis.title.text}"] = "%{y}"

        elif attr_value is not None:
            if attr_name == "size":
                if "marker" not in trace_patch:
                    trace_patch["marker"] = dict()
                trace_patch["marker"]["size"] = trace_data[attr_value]
                trace_patch["marker"]["sizemode"] = "area"
                trace_patch["marker"]["sizeref"] = sizeref
                mapping_labels[attr_label] = "%{marker.size}"
            elif attr_name == "marginal_x":
                if trace_spec.constructor == go.Histogram:
                    mapping_labels["count"] = "%{y}"
            elif attr_name == "marginal_y":
                if trace_spec.constructor == go.Histogram:
                    mapping_labels["count"] = "%{x}"
            elif attr_name == "trendline":
                if (
                    attr_value in ["ols", "lowess"]
                    and args["x"]
                    and args["y"]
                    and len(trace_data[[args["x"], args["y"]]].dropna()) > 1
                ):
                    trend_x, trend_y, fit_results, hover_header = _fit_trendline(
                        args, trace_data, attr_value
                    )
                    trace_patch["x"] = trend_x
                    trace_patch["y"] = trend_y
                    mapping_labels[get_label(args, args["x"])] = "%{x}"
                    mapping_labels[get_label(args, args["y"])] = "%{y} (trend)"
            elif attr_name.startswith("error"):
                error_xy = attr_name[:7]  # "error_x", "error_y" or "error_z"
                arr = "arrayminus" if attr_name.endswith("minus") else "array"
                if error_xy not in trace_patch:
                    trace_patch[error_xy] = {}
                trace_patch[error_xy][arr] = trace_data[attr_value]
            elif attr_name == "custom_data":
                if len(attr_value) > 0:
                    # here we store a data frame in customdata, and it's serialized
                    # as a list of row lists, which is what we want
                    trace_patch["customdata"] = trace_data[attr_value]
            elif attr_name == "hover_name":
                if trace_spec.constructor not in histogram_types:
                    trace_patch["hovertext"] = trace_data[attr_value]
                    if hover_header == "":
                        hover_header = "<b>%{hovertext}</b><br><br>"
            elif attr_name == "hover_data":
                if trace_spec.constructor not in histogram_types:
                    hover_is_dict = isinstance(attr_value, dict)
                    # NOTE: when args["custom_data"] is a non-empty list this is
                    # that same list, so hover columns are appended to it in
                    # place and found by .index() on later traces.
                    customdata_cols = args.get("custom_data") or []
                    for col in attr_value:
                        # NOTE(review): process_args_into_dataframe turns dict
                        # values into (formatter, data) tuples, which are always
                        # truthy; falsy formatters are instead dropped when the
                        # hovertemplate is built.
                        if hover_is_dict and not attr_value[col]:
                            continue
                        if col in [
                            args.get("x"),
                            args.get("y"),
                            args.get("z"),
                            args.get("base"),
                        ]:
                            continue
                        try:
                            position = args["custom_data"].index(col)
                        except (ValueError, AttributeError, KeyError):
                            position = len(customdata_cols)
                            customdata_cols.append(col)
                        attr_label_col = get_decorated_label(args, col, None)
                        mapping_labels[attr_label_col] = "%%{customdata[%d]}" % (
                            position
                        )

                    if len(customdata_cols) > 0:
                        # here we store a data frame in customdata, and it's serialized
                        # as a list of row lists, which is what we want
                        trace_patch["customdata"] = trace_data[customdata_cols]
            elif attr_name == "color":
                if trace_spec.constructor in [go.Choropleth, go.Choroplethmapbox]:
                    trace_patch["z"] = trace_data[attr_value]
                    trace_patch["coloraxis"] = "coloraxis1"
                    mapping_labels[attr_label] = "%{z}"
                elif trace_spec.constructor in sector_types:
                    if "marker" not in trace_patch:
                        trace_patch["marker"] = dict()

                    if args.get("color_is_continuous"):
                        trace_patch["marker"]["colors"] = trace_data[attr_value]
                        trace_patch["marker"]["coloraxis"] = "coloraxis1"
                        mapping_labels[attr_label] = "%{color}"
                    else:
                        # one explicit color per sector: mapped categories keep
                        # their color, new ones cycle through the sequence
                        trace_patch["marker"]["colors"] = []
                        if args["color_discrete_map"] is not None:
                            mapping = args["color_discrete_map"].copy()
                        else:
                            mapping = {}
                        for cat in trace_data[attr_value]:
                            if mapping.get(cat) is None:
                                mapping[cat] = args["color_discrete_sequence"][
                                    len(mapping) % len(args["color_discrete_sequence"])
                                ]
                            trace_patch["marker"]["colors"].append(mapping[cat])
                else:
                    colorable = "marker"
                    if trace_spec.constructor in [go.Parcats, go.Parcoords]:
                        colorable = "line"
                    if colorable not in trace_patch:
                        trace_patch[colorable] = dict()
                    trace_patch[colorable]["color"] = trace_data[attr_value]
                    trace_patch[colorable]["coloraxis"] = "coloraxis1"
                    mapping_labels[attr_label] = "%%{%s.color}" % colorable
            elif attr_name == "animation_group":
                trace_patch["ids"] = trace_data[attr_value]
            elif attr_name == "locations":
                trace_patch[attr_name] = trace_data[attr_value]
                mapping_labels[attr_label] = "%{location}"
            elif attr_name == "values":
                trace_patch[attr_name] = trace_data[attr_value]
                _label = "value" if attr_label == "values" else attr_label
                mapping_labels[_label] = "%{value}"
            elif attr_name == "parents":
                trace_patch[attr_name] = trace_data[attr_value]
                _label = "parent" if attr_label == "parents" else attr_label
                mapping_labels[_label] = "%{parent}"
            elif attr_name == "ids":
                trace_patch[attr_name] = trace_data[attr_value]
                _label = "id" if attr_label == "ids" else attr_label
                mapping_labels[_label] = "%{id}"
            elif attr_name == "names":
                if trace_spec.constructor in sector_types:
                    trace_patch["labels"] = trace_data[attr_value]
                    _label = "label" if attr_label == "names" else attr_label
                    mapping_labels[_label] = "%{label}"
                else:
                    trace_patch[attr_name] = trace_data[attr_value]
            else:
                trace_patch[attr_name] = trace_data[attr_value]
                mapping_labels[attr_label] = "%%{%s}" % attr_name
        elif (trace_spec.constructor == go.Histogram and attr_name in ["x", "y"]) or (
            trace_spec.constructor in [go.Histogram2d, go.Histogram2dContour]
            and attr_name == "z"
        ):
            # ensure that stuff like "count" gets into the hoverlabel
            mapping_labels[attr_label] = "%%{%s}" % attr_name

    if trace_spec.constructor not in [go.Parcoords, go.Parcats]:
        trace_patch["hovertemplate"] = _build_hovertemplate(
            args, hover_header, mapping_labels
        )
    return trace_patch, fit_results
def configure_axes(args, constructor, fig, orders):
    """Configure the subplot axes of *fig* for traces of type *constructor*.

    Dispatches to the matching configure_* function, e.g. 3D scenes,
    ternary, polar, mapbox, geo, or cartesian for every type in
    ``cartesians``. Trace types with no configurator are left untouched.
    """
    configurators = {
        go.Scatter3d: configure_3d_axes,
        go.Scatterternary: configure_ternary_axes,
        go.Scatterpolar: configure_polar_axes,
        go.Scatterpolargl: configure_polar_axes,
        go.Barpolar: configure_polar_axes,
        go.Scattermapbox: configure_mapbox,
        go.Choroplethmapbox: configure_mapbox,
        go.Densitymapbox: configure_mapbox,
        go.Scattergeo: configure_geo,
        go.Choropleth: configure_geo,
    }
    configurators.update(dict.fromkeys(cartesians, configure_cartesian_axes))

    configure = configurators.get(constructor)
    if configure is not None:
        configure(args, fig, orders)
def set_cartesian_axis_opts(args, axis, letter, orders):
    """Apply log type, range and category order to one cartesian axis, in place.

    Parameters
    ----------
    args : dict
        px arguments; reads optional ``log_<letter>`` / ``range_<letter>`` and
        ``args[letter]``, the column mapped to this axis.
    axis : axis object or dict
        updated via item assignment.
    letter : str
        "x" or "y".
    orders : dict
        column name -> list of categories in display order.
    """
    log_key = "log_" + letter
    range_key = "range_" + letter
    if log_key in args and args[log_key]:
        axis["type"] = "log"
        if range_key in args and args[range_key]:
            # log axes take their range as powers of ten. math.log10 is exact
            # for powers of ten; math.log(r, 10) is not (log(1000, 10) < 3).
            axis["range"] = [math.log10(r) for r in args[range_key]]
    elif range_key in args and args[range_key]:
        axis["range"] = args[range_key]

    if args[letter] in orders:
        axis["categoryorder"] = "array"
        axis["categoryarray"] = (
            orders[args[letter]]
            if isinstance(axis, go.layout.XAxis)
            else list(reversed(orders[args[letter]]))  # top down for Y axis
        )
def configure_cartesian_marginal_axes(args, fig, orders):
    """Configure a cartesian figure laid out with marginal subplots.

    The main plot is at row 1 / col 1. Marginal-x subplots occupy the last
    row (``nrows``) and marginal-y subplots the last column (``ncols``).
    Marginal axes lose their tick labels. They match the main axes
    (``matches``) so that they pan and zoom together.
    """
    marginal_x = args["marginal_x"]
    marginal_y = args["marginal_y"]
    if "histogram" in (marginal_x, marginal_y):
        fig.layout["barmode"] = "overlay"

    nrows = len(fig._grid_ref)
    ncols = len(fig._grid_ref[0])

    # Axis options on the left-most column (y) and the bottom-most row (x)
    for yaxis in fig.select_yaxes(col=1):
        set_cartesian_axis_opts(args, yaxis, "y", orders)
    for xaxis in fig.select_xaxes(row=1):
        set_cartesian_axis_opts(args, xaxis, "x", orders)

    # Strip ticks from the marginal subplots; gridlines only where the
    # template leaves showgrid unset
    if marginal_x:
        fig.update_yaxes(
            showticklabels=False, showline=False, ticks="", range=None, row=nrows
        )
        template_layout = args["template"].layout
        if template_layout.yaxis.showgrid is None:
            fig.update_yaxes(showgrid=marginal_x == "histogram", row=nrows)
        if template_layout.xaxis.showgrid is None:
            fig.update_xaxes(showgrid=True, row=nrows)

    if marginal_y:
        fig.update_xaxes(
            showticklabels=False, showline=False, ticks="", range=None, col=ncols
        )
        template_layout = args["template"].layout
        if template_layout.xaxis.showgrid is None:
            fig.update_xaxes(showgrid=marginal_y == "histogram", col=ncols)
        if template_layout.yaxis.showgrid is None:
            fig.update_yaxes(showgrid=True, col=ncols)

    # Titles go on the non-marginal subplots only
    y_title = get_decorated_label(args, args["y"], "y")
    title_rows = [1] if marginal_x else range(1, nrows + 1)
    for row in title_rows:
        fig.update_yaxes(title_text=y_title, row=row, col=1)

    x_title = get_decorated_label(args, args["x"], "x")
    title_cols = [1] if marginal_y else range(1, ncols + 1)
    for col in title_cols:
        fig.update_xaxes(title_text=x_title, row=1, col=col)

    # Log type applies to every axis first...
    if args.get("log_x"):
        fig.update_xaxes(type="log")
    if args.get("log_y"):
        fig.update_yaxes(type="log")

    # ...then marginal axes are reset to follow (match) their main axis
    if marginal_x:
        matches_y = "y" + str(ncols + 1)
        for row in range(2, nrows + 1, 2):
            fig.update_yaxes(matches=matches_y, type=None, row=row)
    if marginal_y:
        for col in range(2, ncols + 1, 2):
            fig.update_xaxes(matches="x2", type=None, col=col)
def configure_cartesian_axes(args, fig, orders):
    """Set titles, log type, ranges and category orders of cartesian axes.

    Figures with marginal subplots are delegated entirely to
    configure_cartesian_marginal_axes. Timeline figures (``"is_timeline"``
    key present in *args*) get no x-axis title and a date x-axis.
    """
    if args.get("marginal_x") or args.get("marginal_y"):
        configure_cartesian_marginal_axes(args, fig, orders)
        return

    is_timeline = "is_timeline" in args

    # Left-most column: y titles and options
    y_title = get_decorated_label(args, args["y"], "y")
    for yaxis in fig.select_yaxes(col=1):
        yaxis.update(title_text=y_title)
        set_cartesian_axis_opts(args, yaxis, "y", orders)

    # Bottom-most row: x titles and options
    x_title = get_decorated_label(args, args["x"], "x")
    for xaxis in fig.select_xaxes(row=1):
        if not is_timeline:
            xaxis.update(title_text=x_title)
        set_cartesian_axis_opts(args, xaxis, "x", orders)

    # Axis types apply to every subplot
    if args.get("log_x"):
        fig.update_xaxes(type="log")
    if args.get("log_y"):
        fig.update_yaxes(type="log")
    if is_timeline:
        fig.update_xaxes(type="date")
def configure_ternary_axes(args, fig, orders):
    """Title the a/b/c axes of every ternary subplot with their column labels."""
    axis_titles = {
        corner + "axis": dict(title_text=get_label(args, args[corner]))
        for corner in "abc"
    }
    fig.update_ternaries(**axis_titles)
def configure_polar_axes(args, fig, orders):
    """Configure every polar subplot of *fig* from the px arguments.

    Sets the angular direction and rotation, and category orders for the r
    and theta columns. Applies the radial log type and range. The sector
    shown is taken from ``range_theta``.
    """
    patch = dict(
        angularaxis=dict(direction=args["direction"], rotation=args["start_angle"]),
        radialaxis=dict(),
    )

    for var, axis in [("r", "radialaxis"), ("theta", "angularaxis")]:
        if args[var] in orders:
            patch[axis]["categoryorder"] = "array"
            patch[axis]["categoryarray"] = orders[args[var]]

    radialaxis = patch["radialaxis"]
    if args["log_r"]:
        radialaxis["type"] = "log"
        if args["range_r"]:
            # log axis ranges are exponents of 10. math.log10 is exact for
            # powers of ten, unlike math.log(x, 10).
            radialaxis["range"] = [math.log10(x) for x in args["range_r"]]
    else:
        if args["range_r"]:
            radialaxis["range"] = args["range_r"]

    if args["range_theta"]:
        patch["sector"] = args["range_theta"]
    fig.update_polars(patch)
def configure_3d_axes(args, fig, orders):
    """Configure the x/y/z axes of every 3D scene in *fig*.

    Each axis gets its column label as title, plus optional log type and
    range. Its category order is applied when the column appears in
    *orders*.
    """
    patch = dict(
        xaxis=dict(title_text=get_label(args, args["x"])),
        yaxis=dict(title_text=get_label(args, args["y"])),
        zaxis=dict(title_text=get_label(args, args["z"])),
    )

    for letter in ["x", "y", "z"]:
        axis = patch[letter + "axis"]
        if args["log_" + letter]:
            axis["type"] = "log"
            if args["range_" + letter]:
                # log axis ranges are exponents of 10. math.log10 is exact for
                # powers of ten, unlike math.log(x, 10).
                axis["range"] = [math.log10(x) for x in args["range_" + letter]]
        else:
            if args["range_" + letter]:
                axis["range"] = args["range_" + letter]
        if args[letter] in orders:
            axis["categoryorder"] = "array"
            axis["categoryarray"] = orders[args[letter]]
    fig.update_scenes(patch)
def configure_mapbox(args, fig, orders):
    """Configure every mapbox subplot: access token, center, zoom and style.

    Without an explicit ``center``, the map is centered on the mean of the
    lat/lon columns, when the px function has those arguments.
    """
    center = args["center"]
    if not center and "lat" in args and "lon" in args:
        data_frame = args["data_frame"]
        center = dict(
            lat=data_frame[args["lat"]].mean(),
            lon=data_frame[args["lon"]].mean(),
        )
    mapbox_options = dict(
        accesstoken=MAPBOX_TOKEN,
        center=center,
        zoom=args["zoom"],
        style=args["mapbox_style"],
    )
    fig.update_mapboxes(**mapbox_options)
def configure_geo(args, fig, orders):
    """Configure every geo subplot from the px geo arguments.

    Sets center, scope, fitbounds, basemap visibility and projection type.
    """
    geo_options = {
        "center": args["center"],
        "scope": args["scope"],
        "fitbounds": args["fitbounds"],
        "visible": args["basemap_visible"],
        "projection": {"type": args["projection"]},
    }
    fig.update_geos(**geo_options)
def configure_animation_controls(args, constructor, fig):
    """Add play/stop buttons and a frame slider to an animated figure.

    Nothing is added unless ``args["animation_frame"]`` is set and *fig*
    has more than one frame. *constructor* is the main trace type: frames
    are redrawn on each step for every type except ``go.Scatter``.
    """

    def frame_args(duration):
        # animate() options shared by the buttons and the slider steps
        return {
            "frame": {"duration": duration, "redraw": constructor != go.Scatter},
            "mode": "immediate",
            "fromcurrent": True,
            "transition": {"duration": duration, "easing": "linear"},
        }

    if not (args.get("animation_frame") and len(fig.frames) > 1):
        return

    play_button = {
        "args": [None, frame_args(500)],
        "label": "▶",
        "method": "animate",
    }
    stop_button = {
        "args": [[None], frame_args(0)],
        "label": "◼",
        "method": "animate",
    }
    fig.layout.updatemenus = [
        {
            "buttons": [play_button, stop_button],
            "direction": "left",
            "pad": {"r": 10, "t": 70},
            "showactive": False,
            "type": "buttons",
            "x": 0.1,
            "xanchor": "right",
            "y": 0,
            "yanchor": "top",
        }
    ]

    slider_steps = [
        {
            "args": [[frame.name], frame_args(0)],
            "label": frame.name,
            "method": "animate",
        }
        for frame in fig.frames
    ]
    fig.layout.sliders = [
        {
            "active": 0,
            "yanchor": "top",
            "xanchor": "left",
            "currentvalue": {
                "prefix": get_label(args, args["animation_frame"]) + "="
            },
            "pad": {"b": 10, "t": 60},
            "len": 0.9,
            "x": 0.1,
            "y": 0,
            "steps": slider_steps,
        }
    ]
def make_trace_spec(args, constructor, attrs, trace_patch):
    """Build the TraceSpecs for a main trace and its companion traces.

    Parameters
    ----------
    args : dict
        px arguments (render_mode, marginal_x/y, trendline, ...).
    constructor : type
        graph_objects trace class of the main trace. go.Scatter and
        go.Scatterpolar are swapped for their WebGL variants when requested
        or, with render_mode "auto", for more than 1000 rows without
        animation.
    attrs : list
        attrables consumed by the main trace.
    trace_patch : dict
        fixed properties of the main trace. NOTE: "orientation" is deleted
        from it in place when switching to go.Scattergl.

    Returns
    -------
    list of TraceSpec
        the main trace first, then one per requested marginal ("x" before
        "y"), then the trendline trace if requested.

    Raises
    ------
    ValueError
        if marginal_x / marginal_y is not one of "histogram", "violin",
        "box" or "rug".
    """
    if constructor in [go.Scatter, go.Scatterpolar]:
        if "render_mode" in args and (
            args["render_mode"] == "webgl"
            or (
                args["render_mode"] == "auto"
                and len(args["data_frame"]) > 1000
                and args["animation_frame"] is None
            )
        ):
            if constructor == go.Scatter:
                constructor = go.Scattergl
                # presumably because Scattergl has no "orientation" -- confirm
                if "orientation" in trace_patch:
                    del trace_patch["orientation"]
            else:
                constructor = go.Scatterpolargl
    # Create base trace specification
    result = [TraceSpec(constructor, attrs, trace_patch, None)]

    # Add marginal trace specifications
    for letter in ["x", "y"]:
        marginal = args.get("marginal_" + letter)
        if not marginal:
            continue
        axis_map = dict(
            xaxis="x1" if letter == "x" else "x2",
            yaxis="y1" if letter == "y" else "y2",
        )
        if marginal == "histogram":
            trace_spec = TraceSpec(
                constructor=go.Histogram,
                attrs=[letter, "marginal_" + letter],
                trace_patch=dict(opacity=0.5, bingroup=letter, **axis_map),
                marginal=letter,
            )
        elif marginal == "violin":
            trace_spec = TraceSpec(
                constructor=go.Violin,
                attrs=[letter, "hover_name", "hover_data"],
                trace_patch=dict(scalegroup=letter),
                marginal=letter,
            )
        elif marginal == "box":
            trace_spec = TraceSpec(
                constructor=go.Box,
                attrs=[letter, "hover_name", "hover_data"],
                trace_patch=dict(notched=True),
                marginal=letter,
            )
        elif marginal == "rug":
            # a transparent box plot whose only visible parts are its points,
            # drawn as tick-mark symbols
            symbols = {"x": "line-ns-open", "y": "line-ew-open"}
            trace_spec = TraceSpec(
                constructor=go.Box,
                attrs=[letter, "hover_name", "hover_data"],
                trace_patch=dict(
                    fillcolor="rgba(255,255,255,0)",
                    line={"color": "rgba(255,255,255,0)"},
                    boxpoints="all",
                    jitter=0,
                    hoveron="points",
                    marker={"symbol": symbols[letter]},
                ),
                marginal=letter,
            )
        else:
            # previously this fell through with trace_spec = None and failed
            # with an opaque AttributeError below
            raise ValueError(
                "Invalid value of 'marginal_%s': %r. Expected one of "
                "'histogram', 'violin', 'box' or 'rug'." % (letter, marginal)
            )
        if "color" in attrs or "color" not in args:
            if "marker" not in trace_spec.trace_patch:
                trace_spec.trace_patch["marker"] = dict()
            first_default_color = args["color_continuous_scale"][0]
            trace_spec.trace_patch["marker"]["color"] = first_default_color
        result.append(trace_spec)

    # Add trendline trace specifications
    if "trendline" in args and args["trendline"]:
        trace_spec = TraceSpec(
            constructor=go.Scattergl if constructor == go.Scattergl else go.Scatter,
            attrs=["trendline"],
            trace_patch=dict(mode="lines"),
            marginal=None,
        )
        if args["trendline_color_override"]:
            trace_spec.trace_patch["line"] = dict(
                color=args["trendline_color_override"]
            )
        result.append(trace_spec)
    return result
def one_group(x):
    """Grouping key that puts every row into the same, unnamed group."""
    del x  # the row is irrelevant: all rows share the key ""
    return ""
def apply_default_cascade(args):
    """Fill styling arguments left as None in *args*, in place.

    Precedence, highest first:

    1. the value passed by the caller;
    2. ``px.defaults``;
    3. the template;
    4. hard-coded fallbacks.

    ``args["template"]`` always ends up as an actual template object.
    """
    # 1. px.defaults, for every argument the caller left as None
    for param in defaults.__slots__:
        if param in args and args[param] is None:
            args[param] = getattr(defaults, param)

    # 2. the template: explicit, else the pio default, else "plotly"
    if args["template"] is None:
        if pio.templates.default is not None:
            args["template"] = pio.templates.default
        else:
            args["template"] = "plotly"
    try:
        # look the template up by name in the registry
        args["template"] = pio.templates[args["template"]]
    except Exception:
        # otherwise build a Template from whatever was given
        args["template"] = go.layout.Template(args["template"])
    template = args["template"]

    # 3. colors: template first, then fallbacks
    if "color_continuous_scale" in args and args["color_continuous_scale"] is None:
        sequential_scale = template.layout.colorscale.sequential
        if sequential_scale:
            args["color_continuous_scale"] = [x[1] for x in sequential_scale]
        else:
            args["color_continuous_scale"] = sequential.Viridis

    if "color_discrete_sequence" in args and args["color_discrete_sequence"] is None:
        colorway = template.layout.colorway
        args["color_discrete_sequence"] = colorway if colorway else qualitative.D3

    # 4. symbol / dash / pattern sequences: template traces, then fallbacks
    def sequence_from_template(key, trace_type, extract, fallback):
        if key not in args:
            return
        if args[key] is None:
            template_traces = getattr(template.data, trace_type)
            if template_traces:
                args[key] = [extract(trace) for trace in template_traces]
        if not args[key] or not any(args[key]):
            args[key] = fallback

    sequence_from_template(
        "symbol_sequence",
        "scatter",
        lambda trace: trace.marker.symbol,
        ["circle", "diamond", "square", "x", "cross"],
    )
    sequence_from_template(
        "line_dash_sequence",
        "scatter",
        lambda trace: trace.line.dash,
        ["solid", "dot", "dash", "longdash", "dashdot", "longdashdot"],
    )
    sequence_from_template(
        "pattern_shape_sequence",
        "bar",
        lambda trace: trace.marker.pattern.shape,
        ["", "/", "\\", "x", "+", "."],
    )
def _check_name_not_reserved(field_name, reserved_names):
if field_name not in reserved_names:
return field_name
else:
raise NameError(
"A name conflict was encountered for argument '%s'. "
"A column or index with name '%s' is ambiguous." % (field_name, field_name)
)
def _get_reserved_col_names(args):
    """
    This function builds a list of columns of the data_frame argument used
    as arguments, either as str/int arguments or given as columns
    (pandas series type).

    A name is reserved when an attrable argument refers to it as a string,
    or passes the very Series ``data_frame[name]``, or passes the named
    ``data_frame.index``. Integer references are not recorded.

    Returns a set of names.
    """
    df = args["data_frame"]
    reserved_names = set()
    for field in args:
        if field not in all_attrables:
            continue
        names = args[field] if field in array_attrables else [args[field]]
        if names is None:
            continue
        for arg in names:
            if arg is None:
                continue
            elif isinstance(arg, str):  # no need to add ints since kw arg are not ints
                reserved_names.add(arg)
            elif isinstance(arg, pd.Series):
                arg_name = arg.name
                # Test column membership, not hasattr(df, name):
                # - DataFrame attributes such as "size" or "shape" are not
                #   columns, so df[name] raised KeyError;
                # - non-identifier names were missed;
                # - non-str names made hasattr raise TypeError.
                if arg_name and arg_name in df.columns:
                    in_df = arg is df[arg_name]
                    if in_df:
                        reserved_names.add(arg_name)
            elif arg is df.index and arg.name is not None:
                reserved_names.add(arg.name)
    return reserved_names
def _is_col_list(df_input, arg):
"""Returns True if arg looks like it's a list of columns or references to columns
in df_input, and False otherwise (in which case it's assumed to be a single column
or reference to a column).
"""
if arg is None or isinstance(arg, str) or isinstance(arg, int):
return False
if isinstance(arg, pd.MultiIndex):
return False # just to keep existing behaviour for now
try:
iter(arg)
except TypeError:
return False # not iterable
for c in arg:
if isinstance(c, str) or isinstance(c, int):
if df_input is None or c not in df_input.columns:
return False
else:
try:
iter(c)
except TypeError:
return False # not iterable
return True
def _isinstance_listlike(x):
"""Returns True if x is an iterable which can be transformed into a pandas Series,
False for the other types of possible values of a `hover_data` dict.
A tuple of length 2 is a special case corresponding to a (format, data) tuple.
"""
if (
isinstance(x, str)
or (isinstance(x, tuple) and len(x) == 2)
or isinstance(x, bool)
or x is None
):
return False
else:
return True
def _escape_col_name(df_input, col_name, extra):
while df_input is not None and (col_name in df_input.columns or col_name in extra):
col_name = "_" + col_name
return col_name
def to_unindexed_series(x):
    """
    assuming x is list-like or even an existing pd.Series, return a new pd.Series with
    no index, without extracting the data from an existing Series via numpy, which
    seems to mangle datetime columns. Stripping the index from existing pd.Series is
    required to get things to match up right in the new DataFrame we're building
    """
    series = pd.Series(x)
    # drop=True discards the old index instead of turning it into a column
    return series.reset_index(drop=True)
def process_args_into_dataframe(args, wide_mode, var_name, value_name):
    """
    After this function runs, the `all_attrables` keys of `args` all contain only
    references to columns of `df_output`. This function handles the extraction of data
    from `args["attrable"]` and column-name-generation as appropriate, and adds the
    data to `df_output` and then replaces `args["attrable"]` with the appropriate
    reference.

    Parameters
    ----------
    args : dict
        px arguments. Mutated in place: attrable entries are replaced by column
        names, and the values of a ``hover_data`` dict are normalized to
        2-tuples.
    wide_mode : bool
        Whether wide-form input is being processed. In that case a string
        argument equal to ``var_name`` or ``value_name`` names a column that
        will only exist after melting, so it is skipped here.
    var_name, value_name : str or None
        Names reserved for the melted variable and value columns.

    Returns
    -------
    (pd.DataFrame, set)
        The new data frame, and the set of its column names that are not
        produced from ``wide_variable``.

    Raises
    ------
    ValueError
        On missing data, unknown column names, ambiguous ``hover_data`` or
        arguments whose lengths disagree.
    TypeError
        When an argument is a pandas MultiIndex.
    """
    df_input = args["data_frame"]
    df_provided = df_input is not None

    df_output = pd.DataFrame()
    constants = dict()
    ranges = list()
    wide_id_vars = set()
    reserved_names = _get_reserved_col_names(args) if df_provided else set()

    # Case of functions with a "dimensions" kw: scatter_matrix, parcats, parcoords
    if "dimensions" in args and args["dimensions"] is None:
        if not df_provided:
            raise ValueError(
                "No data were provided. Please provide data either with the `data_frame` or with the `dimensions` argument."
            )
        else:
            df_output[df_input.columns] = df_input[df_input.columns]

    # hover_data is a dict
    hover_data_is_dict = (
        "hover_data" in args
        and args["hover_data"]
        and isinstance(args["hover_data"], dict)
    )
    # If dict, convert all values of hover_data to tuples to simplify processing
    if hover_data_is_dict:
        for k in args["hover_data"]:
            if _isinstance_listlike(args["hover_data"][k]):
                args["hover_data"][k] = (True, args["hover_data"][k])
            if not isinstance(args["hover_data"][k], tuple):
                args["hover_data"][k] = (args["hover_data"][k], None)
            if df_provided and args["hover_data"][k][1] is not None and k in df_input:
                raise ValueError(
                    "Ambiguous input: values for '%s' appear both in hover_data and data_frame"
                    % k
                )

    # Loop over possible arguments
    for field_name in all_attrables:
        # Massaging variables
        argument_list = (
            [args.get(field_name)]
            if field_name not in array_attrables
            else args.get(field_name)
        )
        # argument not specified, continue.
        # A [None] list (unset scalar attrable) needs no special case: None
        # entries are skipped inside the loop below. Do not test
        # `argument_list == [None]` here: comparing a one-element list holding
        # a Series/ndarray with [None] evaluates an element-wise result as a
        # bool and raises.
        if argument_list is None:
            continue
        # Argument name: field_name if the argument is not a list
        # Else we give names like ["hover_data_0, hover_data_1"] etc.
        field_list = (
            [field_name]
            if field_name not in array_attrables
            else [field_name + "_" + str(i) for i in range(len(argument_list))]
        )
        # argument_list and field_list ready, iterate over them
        # Core of the loop starts here
        for i, (argument, field) in enumerate(zip(argument_list, field_list)):
            length = len(df_output)
            if argument is None:
                continue
            col_name = None
            # Case of multiindex
            if isinstance(argument, pd.MultiIndex):
                raise TypeError(
                    "Argument '%s' is a pandas MultiIndex. "
                    "pandas MultiIndex is not supported by plotly express "
                    "at the moment." % field
                )
            # ----------------- argument is a special value ----------------------
            if isinstance(argument, (Constant, Range)):
                col_name = _check_name_not_reserved(
                    str(argument.label) if argument.label is not None else field,
                    reserved_names,
                )
                # the columns themselves are added once df_output's length is known
                if isinstance(argument, Constant):
                    constants[col_name] = argument.value
                else:
                    ranges.append(col_name)
            # ----------------- argument is likely a col name ----------------------
            elif isinstance(argument, str) or not hasattr(argument, "__len__"):
                if (
                    field_name == "hover_data"
                    and hover_data_is_dict
                    and args["hover_data"][str(argument)][1] is not None
                ):
                    # hover_data has onboard data
                    # previously-checked to have no name-conflict with data_frame
                    col_name = str(argument)
                    real_argument = args["hover_data"][col_name][1]

                    if length and len(real_argument) != length:
                        raise ValueError(
                            "All arguments should have the same length. "
                            "The length of hover_data key `%s` is %d, whereas the "
                            "length of previously-processed arguments %s is %d"
                            % (
                                argument,
                                len(real_argument),
                                str(list(df_output.columns)),
                                length,
                            )
                        )
                    df_output[col_name] = to_unindexed_series(real_argument)
                elif not df_provided:
                    raise ValueError(
                        "String or int arguments are only possible when a "
                        "DataFrame or an array is provided in the `data_frame` "
                        "argument. No DataFrame was provided, but argument "
                        "'%s' is of type str or int." % field
                    )
                # Check validity of column name
                elif argument not in df_input.columns:
                    if wide_mode and argument in (value_name, var_name):
                        # will be created by the wide-mode melt later on
                        continue
                    else:
                        err_msg = (
                            "Value of '%s' is not the name of a column in 'data_frame'. "
                            "Expected one of %s but received: %s"
                            % (field, str(list(df_input.columns)), argument)
                        )
                        if argument == "index":
                            err_msg += "\n To use the index, pass it in directly as `df.index`."
                        raise ValueError(err_msg)
                elif length and len(df_input[argument]) != length:
                    raise ValueError(
                        "All arguments should have the same length. "
                        "The length of column argument `df[%s]` is %d, whereas the "
                        "length of previously-processed arguments %s is %d"
                        % (
                            field,
                            len(df_input[argument]),
                            str(list(df_output.columns)),
                            length,
                        )
                    )
                else:
                    col_name = str(argument)
                    df_output[col_name] = to_unindexed_series(df_input[argument])
            # ----------------- argument is likely a column / array / list.... -------
            else:
                if df_provided and hasattr(argument, "name"):
                    if argument is df_input.index:
                        if argument.name is None or argument.name in df_input:
                            col_name = "index"
                        else:
                            col_name = argument.name
                        col_name = _escape_col_name(
                            df_input, col_name, [var_name, value_name]
                        )
                    else:
                        # a column taken straight from data_frame keeps its name
                        if (
                            argument.name is not None
                            and argument.name in df_input
                            and argument is df_input[argument.name]
                        ):
                            col_name = argument.name
                if col_name is None:  # numpy array, list...
                    col_name = _check_name_not_reserved(field, reserved_names)

                if length and len(argument) != length:
                    raise ValueError(
                        "All arguments should have the same length. "
                        "The length of argument `%s` is %d, whereas the "
                        "length of previously-processed arguments %s is %d"
                        % (field, len(argument), str(list(df_output.columns)), length)
                    )
                df_output[str(col_name)] = to_unindexed_series(argument)

            # Finally, update argument with column name now that column exists
            assert col_name is not None, (
                "Data-frame processing failure, likely due to a internal bug. "
                "Please report this to "
                "https://github.com/plotly/plotly.py/issues/new and we will try to "
                "replicate and fix it."
            )
            if field_name not in array_attrables:
                args[field_name] = str(col_name)
            elif isinstance(args[field_name], dict):
                # hover_data dict keys are kept as given
                pass
            else:
                args[field_name][i] = str(col_name)
            if field_name != "wide_variable":
                wide_id_vars.add(str(col_name))

    for col_name in ranges:
        df_output[col_name] = range(len(df_output))

    for col_name in constants:
        df_output[col_name] = constants[col_name]

    return df_output, wide_id_vars
def build_dataframe(args, constructor):
    """
    Constructs a dataframe and modifies `args` in-place.

    The argument values in `args` can be either strings corresponding to
    existing columns of a dataframe, or data arrays (lists, numpy arrays,
    pandas columns, series).

    Parameters
    ----------
    args : OrderedDict
        arguments passed to the px function and subsequently modified
    constructor : graph_object trace class
        the trace type selected for this figure

    Returns
    -------
    args
        The same mapping. Its ``data_frame`` entry now holds the newly built
        DataFrame, which is melted to long form in wide mode, and its
        attrable entries refer to that frame's columns.

    Raises
    ------
    ValueError, TypeError
        For unsupported argument combinations (list-like ``x`` and ``y``,
        MultiIndex columns or index, mixed-type wide-form columns), plus
        whatever `process_args_into_dataframe` raises.
    """

    # make copies of all the fields via dict() and list()
    for field in args:
        if field in array_attrables and args[field] is not None:
            args[field] = (
                dict(args[field])
                if isinstance(args[field], dict)
                else list(args[field])
            )

    # Cast data_frame argument to DataFrame (it could be a numpy array, dict etc.)
    df_provided = args["data_frame"] is not None
    if df_provided and not isinstance(args["data_frame"], pd.DataFrame):
        args["data_frame"] = pd.DataFrame(args["data_frame"])
    df_input = args["data_frame"]

    # now we handle special cases like wide-mode or x-xor-y specification
    # by rearranging args to tee things up for process_args_into_dataframe to work
    no_x = args.get("x") is None
    no_y = args.get("y") is None
    wide_x = False if no_x else _is_col_list(df_input, args["x"])
    wide_y = False if no_y else _is_col_list(df_input, args["y"])

    wide_mode = False
    var_name = None  # will likely be "variable" in wide_mode
    wide_cross_name = None  # will likely be "index" in wide_mode
    value_name = None  # will likely be "value" in wide_mode
    hist2d_types = [go.Histogram2d, go.Histogram2dContour]
    if constructor in cartesians:
        if wide_x and wide_y:
            raise ValueError(
                "Cannot accept list of column references or list of columns for both `x` and `y`."
            )
        if df_provided and no_x and no_y:
            # neither x nor y given: every column of the data frame is a series
            wide_mode = True
            if isinstance(df_input.columns, pd.MultiIndex):
                raise TypeError(
                    "Data frame columns is a pandas MultiIndex. "
                    "pandas MultiIndex is not supported by plotly express "
                    "at the moment."
                )
            args["wide_variable"] = list(df_input.columns)
            var_name = df_input.columns.name
            if var_name in [None, "value", "index"] or var_name in df_input:
                var_name = "variable"
            if constructor == go.Funnel:
                wide_orientation = args.get("orientation") or "h"
            else:
                wide_orientation = args.get("orientation") or "v"
            args["orientation"] = wide_orientation
            args["wide_cross"] = None
        elif wide_x != wide_y:
            # exactly one of x / y is a list of columns
            wide_mode = True
            args["wide_variable"] = args["y"] if wide_y else args["x"]
            if df_provided and args["wide_variable"] is df_input.columns:
                var_name = df_input.columns.name
            if isinstance(args["wide_variable"], pd.Index):
                args["wide_variable"] = list(args["wide_variable"])
            if var_name in [None, "value", "index"] or (
                df_provided and var_name in df_input
            ):
                var_name = "variable"
            if constructor == go.Histogram:
                wide_orientation = "v" if wide_x else "h"
            else:
                wide_orientation = "v" if wide_y else "h"
            args["y" if wide_y else "x"] = None
            args["wide_cross"] = None
            if not no_x and not no_y:
                # the non-list axis is the cross dimension; resolved after processing
                wide_cross_name = "__x__" if wide_y else "__y__"

    if wide_mode:
        value_name = _escape_col_name(df_input, "value", [])
        var_name = _escape_col_name(df_input, var_name, [])

    missing_bar_dim = None
    if constructor in [go.Scatter, go.Bar, go.Funnel] + hist2d_types:
        if not wide_mode and (no_x != no_y):
            # only one of x / y given: use the index (or a range) for the other
            for ax in ["x", "y"]:
                if args.get(ax) is None:
                    args[ax] = df_input.index if df_provided else Range()
                    if constructor == go.Bar:
                        missing_bar_dim = ax
                    else:
                        if args["orientation"] is None:
                            args["orientation"] = "v" if ax == "x" else "h"
        if wide_mode and wide_cross_name is None:
            if no_x != no_y and args["orientation"] is None:
                args["orientation"] = "v" if no_x else "h"
            if df_provided:
                if isinstance(df_input.index, pd.MultiIndex):
                    raise TypeError(
                        "Data frame index is a pandas MultiIndex. "
                        "pandas MultiIndex is not supported by plotly express "
                        "at the moment."
                    )
                args["wide_cross"] = df_input.index
            else:
                args["wide_cross"] = Range(
                    label=_escape_col_name(df_input, "index", [var_name, value_name])
                )

    # NO_COLOR is a sentinel string; the isinstance guard keeps array-like
    # color arguments (Series, ndarray) away from the string comparison.
    no_color = False
    if isinstance(args.get("color"), str) and args["color"] == NO_COLOR:
        no_color = True
        args["color"] = None
    # now that things have been prepped, we do the systematic rewriting of `args`

    df_output, wide_id_vars = process_args_into_dataframe(
        args, wide_mode, var_name, value_name
    )

    # now that `df_output` exists and `args` contains only references, we complete
    # the special-case and wide-mode handling by further rewriting args and/or mutating
    # df_output

    count_name = _escape_col_name(df_output, "count", [var_name, value_name])
    if not wide_mode and missing_bar_dim and constructor == go.Bar:
        # now that we've populated df_output, we check to see if the non-missing
        # dimension is categorical: if so, then setting the missing dimension to a
        # constant 1 is a less-insane thing to do than setting it to the index by
        # default and we let the normal auto-orientation-code do its thing later
        other_dim = "x" if missing_bar_dim == "y" else "y"
        if not _is_continuous(df_output, args[other_dim]):
            args[missing_bar_dim] = count_name
            df_output[count_name] = 1
        else:
            # on the other hand, if the non-missing dimension is continuous, then we
            # can use this information to override the normal auto-orientation code
            if args["orientation"] is None:
                args["orientation"] = "v" if missing_bar_dim == "x" else "h"

    if constructor in hist2d_types:
        del args["orientation"]

    if wide_mode:
        # at this point, `df_output` is semi-long/semi-wide, but we know which columns
        # are which, so we melt it and reassign `args` to refer to the newly-tidy
        # columns, keeping track of various names and manglings set up above
        wide_value_vars = [c for c in args["wide_variable"] if c not in wide_id_vars]
        del args["wide_variable"]
        if wide_cross_name == "__x__":
            wide_cross_name = args["x"]
        elif wide_cross_name == "__y__":
            wide_cross_name = args["y"]
        else:
            wide_cross_name = args["wide_cross"]
        del args["wide_cross"]
        dtype = None
        for v in wide_value_vars:
            v_dtype = df_output[v].dtype.kind
            v_dtype = "number" if v_dtype in ["i", "f", "u"] else v_dtype
            if dtype is None:
                dtype = v_dtype
            elif dtype != v_dtype:
                raise ValueError(
                    "Plotly Express cannot process wide-form data with columns of different type."
                )
        # Pass the id columns in df_output's own column order: iterating the
        # set directly would make the melted column order depend on string
        # hash randomization.
        ordered_id_vars = [c for c in df_output.columns if c in wide_id_vars]
        df_output = df_output.melt(
            id_vars=ordered_id_vars,
            value_vars=wide_value_vars,
            var_name=var_name,
            value_name=value_name,
        )
        assert len(df_output.columns) == len(set(df_output.columns)), (
            "Wide-mode name-inference failure, likely due to a internal bug. "
            "Please report this to "
            "https://github.com/plotly/plotly.py/issues/new and we will try to "
            "replicate and fix it."
        )
        df_output[var_name] = df_output[var_name].astype(str)
        orient_v = wide_orientation == "v"

        if constructor in [go.Scatter, go.Funnel] + hist2d_types:
            args["x" if orient_v else "y"] = wide_cross_name
            args["y" if orient_v else "x"] = value_name
            if constructor != go.Histogram2d:
                args["color"] = args["color"] or var_name
            if "line_group" in args:
                args["line_group"] = args["line_group"] or var_name
        if constructor == go.Bar:
            if _is_continuous(df_output, value_name):
                args["x" if orient_v else "y"] = wide_cross_name
                args["y" if orient_v else "x"] = value_name
                args["color"] = args["color"] or var_name
            else:
                # categorical values: count occurrences instead
                args["x" if orient_v else "y"] = value_name
                args["y" if orient_v else "x"] = count_name
                df_output[count_name] = 1
                args["color"] = args["color"] or var_name
        if constructor in [go.Violin, go.Box]:
            args["x" if orient_v else "y"] = wide_cross_name or var_name
            args["y" if orient_v else "x"] = value_name
        if constructor == go.Histogram:
            args["x" if orient_v else "y"] = value_name
            args["y" if orient_v else "x"] = wide_cross_name
            args["color"] = args["color"] or var_name
    if no_color:
        args["color"] = None
    args["data_frame"] = df_output
    return args
def _check_dataframe_all_leaves(df):
    """Validate the path columns of a sunburst/treemap/icicle data frame.

    ``df`` holds one column per path level, ordered from root to leaf, and
    every row must describe a leaf of the hierarchy:

    * once a level is None/NaN, every deeper level of that row must be
      None/NaN too, and
    * a row that stops early (it has trailing nulls) must not name a node
      that another row extends with deeper levels.

    Raises
    ------
    ValueError
        If either rule is violated. The offending row (nulls shown as "") is
        passed along as an extra exception argument.
    """
    df_sorted = df.sort_values(by=list(df.columns))
    null_mask = df_sorted.isnull()
    df_sorted = df_sorted.astype(str)
    null_indices = np.nonzero(null_mask.any(axis=1).values)[0]
    for null_row_index in null_indices:
        row = null_mask.iloc[null_row_index]
        first_null = np.nonzero(row.values)[0][0]
        # positional slice (column labels may themselves be integers)
        if not row.iloc[first_null:].all():
            raise ValueError(
                "None entries cannot have not-None children",
                df_sorted.iloc[null_row_index],
            )
    df_sorted[null_mask] = ""

    # After sorting (NaN last), rows sharing the same non-null prefix are
    # contiguous, and the rows that stop early sit at the end of that run. So a
    # row that stops after `depth` levels is a non-leaf exactly when the row
    # just before it has the same first `depth` (non-null) values. Comparing
    # level by level avoids the false positives of substring matching on
    # concatenated labels, e.g. a leaf "x" next to a path ("bx", "y").
    values = df_sorted.to_numpy()
    nulls = null_mask.to_numpy()
    for pos in null_indices:
        if pos == 0:
            continue
        depth = np.nonzero(nulls[pos])[0][0]
        prev = pos - 1
        same_prefix = (values[prev, :depth] == values[pos, :depth]).all()
        if same_prefix and not nulls[prev, :depth].any():
            raise ValueError(
                "Non-leaves rows are not permitted in the dataframe \n",
                df_sorted.iloc[pos],
                "is not a leaf.",
            )
def process_dataframe_hierarchy(args):
    """
    Build dataframe for sunburst, treemap, or icicle when the path argument is provided.

    For every level of ``args["path"]``, rows are grouped by that level and all
    of its ancestors. ``values`` (or a helper count column) is summed. A
    discrete color column is reduced to its single value, or "(?)" when mixed.
    A continuous color column is reduced to its ``values``-weighted mean.
    Every other column goes through the same single-value-or-"(?)" rule. The
    nodes of all levels are stacked into one DataFrame with
    ``labels``/``parent``/``id`` columns, and ``args`` is rewritten to use it.

    ``args`` is modified in place and returned. ``args["data_frame"]`` also
    gains helper columns as a side effect.

    Raises
    ------
    ValueError
        If the path columns contain non-leaf rows, or the ``values`` column
        cannot be converted to a numeric type.
    """
    df = args["data_frame"]
    path = args["path"][::-1]
    _check_dataframe_all_leaves(df[path[::-1]])
    discrete_color = False

    # copy the path columns so that the originals can also be aggregated
    # (e.g. when a path column is used for color or hover_data)
    new_path = []
    for col_name in path:
        new_col_name = col_name + "_path_copy"
        new_path.append(new_col_name)
        df[new_col_name] = df[col_name]
    path = new_path
    # ------------ Define aggregation functions --------------------------------

    def aggfunc_discrete(x):
        uniques = x.unique()
        if len(uniques) == 1:
            return uniques[0]
        else:
            return "(?)"

    agg_f = {}
    aggfunc_color = None
    if args["values"]:
        try:
            df[args["values"]] = pd.to_numeric(df[args["values"]])
        except ValueError:
            raise ValueError(
                "Column `%s` of `df` could not be converted to a numerical data type."
                % args["values"]
            )

        if args["color"]:
            if args["color"] == args["values"]:
                # the same column is summed for values and averaged for color
                new_value_col_name = args["values"] + "_sum"
                df[new_value_col_name] = df[args["values"]]
                args["values"] = new_value_col_name
        count_colname = args["values"]
    else:
        # we need a count column for the first groupby and the weighted mean of color
        # trick to be sure the col name is unused: take the sum of existing names
        count_colname = (
            "count"
            if "count" not in df.columns
            else "".join([str(el) for el in list(df.columns)])
        )
        # we can modify df because it's a copy of the px argument
        df[count_colname] = 1
        args["values"] = count_colname
    agg_f[count_colname] = "sum"

    if args["color"]:
        if not _is_continuous(df, args["color"]):
            aggfunc_color = aggfunc_discrete
            discrete_color = True
        else:

            def aggfunc_continuous(x):
                return np.average(x, weights=df.loc[x.index, count_colname])

            aggfunc_color = aggfunc_continuous
        agg_f[args["color"]] = aggfunc_color

    # Other columns (for color, hover_data, custom_data etc.)
    cols = list(set(df.columns).difference(path))
    for col in cols:  # for hover_data, custom_data etc.
        if col not in agg_f:
            agg_f[col] = aggfunc_discrete
    # Avoid collisions with reserved names - columns in the path have been copied already
    cols = list(set(cols) - set(["labels", "parent", "id"]))
    # ----------------------------------------------------------------------------
    df_all_trees = pd.DataFrame(columns=["labels", "parent", "id"] + cols)
    # Set column type here (useful for continuous vs discrete colorscale)
    for col in cols:
        df_all_trees[col] = df_all_trees[col].astype(df[col].dtype)
    for i, level in enumerate(path):
        df_tree = pd.DataFrame(columns=df_all_trees.columns)
        dfg = df.groupby(path[i:]).agg(agg_f)
        dfg = dfg.reset_index()
        # Path label massaging
        df_tree["labels"] = dfg[level].copy().astype(str)
        df_tree["parent"] = ""
        df_tree["id"] = dfg[level].copy().astype(str)
        # prepend every ancestor level (path is stored leaf-first)
        for ancestor in path[i + 1 :]:
            df_tree["parent"] = (
                dfg[ancestor].copy().astype(str) + "/" + df_tree["parent"]
            )
            df_tree["id"] = dfg[ancestor].copy().astype(str) + "/" + df_tree["id"]

        df_tree["parent"] = df_tree["parent"].str.rstrip("/")
        if cols:
            df_tree[cols] = dfg[cols]
        # DataFrame.append was removed in pandas 2.0
        df_all_trees = pd.concat([df_all_trees, df_tree], ignore_index=True)

    # we want to make sure than (?) is the first color of the sequence.
    # The sort key must come from the tree nodes themselves: `df` holds the
    # original data rows, and assigning one of its columns here would align
    # on unrelated index labels. "(" precedes digits and letters in ASCII, so
    # "(?)" normally sorts first; the stable sort keeps level order on ties.
    if args["color"] and discrete_color:
        sort_col_name = "sort_color_if_discrete_color"
        while sort_col_name in df_all_trees.columns:
            sort_col_name += "0"
        df_all_trees[sort_col_name] = df_all_trees[args["color"]].astype(str)
        df_all_trees = df_all_trees.sort_values(by=sort_col_name, kind="mergesort")

    # Now modify arguments
    args["data_frame"] = df_all_trees
    args["path"] = None
    args["ids"] = "id"
    args["names"] = "labels"
    args["parents"] = "parent"
    if args["color"]:
        # show the aggregated color value in the hover label
        if not args["hover_data"]:
            args["hover_data"] = [args["color"]]
        elif isinstance(args["hover_data"], dict):
            if not args["hover_data"].get(args["color"]):
                args["hover_data"][args["color"]] = (True, None)
        else:
            args["hover_data"].append(args["color"])
    return args
def process_dataframe_timeline(args):
"""
Massage input for bar traces for px.timeline()
"""
args["is_timeline"] = True
if args["x_start"] is None or args["x_end"] is None:
raise ValueError("Both x_start and x_end are required")
try:
x_start = pd.to_datetime(args["data_frame"][args["x_start"]])
x_end = pd.to_datetime(args["data_frame"][args["x_end"]])
except (ValueError, TypeError):
raise TypeError(
"Both x_start and x_end must refer to data convertible to datetimes."
)
# note that we are not adding any columns to the data frame here, so no risk of overwrite
args["data_frame"][args["x_end"]] = (x_end - x_start).astype("timedelta64[ms]")
args["x"] = args["x_end"]
del args["x_end"]
args["base"] = args["x_start"]
del args["x_start"]
return args
def infer_config(args, constructor, trace_patch, layout_patch):
    """Derive trace specs, grouping mappings and colorbar settings from ``args``.

    Also normalizes several entries of ``args``, ``trace_patch`` and
    ``layout_patch`` in place: orientation, histfunc, bins, box/violin mode,
    opacity, mode, line shape, geojson and marginal/facet compatibility.

    Returns
    -------
    (trace_specs, grouped_mappings, sizeref, show_colorbar)
        ``trace_specs`` come from `make_trace_spec`. ``grouped_mappings`` holds
        one `make_mapping` result per grouping attribute. ``sizeref`` is the
        marker size reference (0 when no ``size`` column is used), and
        ``show_colorbar`` says whether a continuous color axis is needed.
    """
    attrs = [k for k in direct_attrables + array_attrables if k in args]
    grouped_attrs = []

    # Compute sizeref
    sizeref = 0
    if "size" in args and args["size"]:
        sizeref = args["data_frame"][args["size"]].max() / args["size_max"] ** 2

    # Compute color attributes and grouping attributes
    if "color" in args:
        if "color_continuous_scale" in args:
            if "color_discrete_sequence" not in args:
                attrs.append("color")
            else:
                if args["color"] and _is_continuous(args["data_frame"], args["color"]):
                    attrs.append("color")
                    args["color_is_continuous"] = True
                elif constructor in [go.Sunburst, go.Treemap, go.Icicle]:
                    attrs.append("color")
                    args["color_is_continuous"] = False
                else:
                    grouped_attrs.append("marker.color")
        elif "line_group" in args or constructor == go.Histogram2dContour:
            grouped_attrs.append("line.color")
        elif constructor in [go.Pie, go.Funnelarea]:
            attrs.append("color")
            if args["color"]:
                # Show the color column in the hover label. hover_data may be
                # None, a list, or a dict (when the user passed a dict); the
                # previous unconditional .append() crashed on dicts.
                if not args["hover_data"]:
                    args["hover_data"] = [args["color"]]
                elif isinstance(args["hover_data"], dict):
                    if not args["hover_data"].get(args["color"]):
                        args["hover_data"][args["color"]] = (True, None)
                else:
                    args["hover_data"].append(args["color"])
        else:
            grouped_attrs.append("marker.color")

        show_colorbar = bool(
            "color" in attrs
            and args["color"]
            and constructor not in [go.Pie, go.Funnelarea]
            and (
                constructor not in [go.Treemap, go.Sunburst, go.Icicle]
                or args.get("color_is_continuous")
            )
        )
    else:
        show_colorbar = False

    if "line_dash" in args:
        grouped_attrs.append("line.dash")

    if "symbol" in args:
        grouped_attrs.append("marker.symbol")

    if "pattern_shape" in args:
        grouped_attrs.append("marker.pattern.shape")

    if "orientation" in args:
        has_x = args["x"] is not None
        has_y = args["y"] is not None
        if args["orientation"] is None:
            if constructor in [go.Histogram, go.Scatter]:
                if has_y and not has_x:
                    args["orientation"] = "h"
            elif constructor in [go.Violin, go.Box, go.Bar, go.Funnel]:
                if has_x and not has_y:
                    args["orientation"] = "h"

        if args["orientation"] is None and has_x and has_y:
            # the continuous axis is the value axis
            x_is_continuous = _is_continuous(args["data_frame"], args["x"])
            y_is_continuous = _is_continuous(args["data_frame"], args["y"])
            if x_is_continuous and not y_is_continuous:
                args["orientation"] = "h"
            if y_is_continuous and not x_is_continuous:
                args["orientation"] = "v"

        if args["orientation"] is None:
            args["orientation"] = "v"

        if constructor == go.Histogram:
            if has_x and has_y and args["histfunc"] is None:
                args["histfunc"] = trace_patch["histfunc"] = "sum"

            orientation = args["orientation"]
            nbins = args["nbins"]
            trace_patch["nbinsx"] = nbins if orientation == "v" else None
            trace_patch["nbinsy"] = None if orientation == "v" else nbins
            trace_patch["bingroup"] = "x" if orientation == "v" else "y"
        trace_patch["orientation"] = args["orientation"]

        if constructor in [go.Violin, go.Box]:
            mode = "boxmode" if constructor == go.Box else "violinmode"
            if layout_patch[mode] is None and args["color"] is not None:
                # color on the category axis: one trace per category, overlay them
                if args["y"] == args["color"] and args["orientation"] == "h":
                    layout_patch[mode] = "overlay"
                elif args["x"] == args["color"] and args["orientation"] == "v":
                    layout_patch[mode] = "overlay"
            if layout_patch[mode] is None:
                layout_patch[mode] = "group"

    if (
        constructor == go.Histogram2d
        and args["z"] is not None
        and args["histfunc"] is None
    ):
        args["histfunc"] = trace_patch["histfunc"] = "sum"

    if constructor in [go.Histogram2d, go.Densitymapbox]:
        show_colorbar = True
        trace_patch["coloraxis"] = "coloraxis1"

    if "opacity" in args:
        if args["opacity"] is None:
            if "barmode" in args and args["barmode"] == "overlay":
                trace_patch["marker"] = dict(opacity=0.5)
        elif constructor in [go.Densitymapbox, go.Pie, go.Funnel, go.Funnelarea]:
            trace_patch["opacity"] = args["opacity"]
        else:
            trace_patch["marker"] = dict(opacity=args["opacity"])
    if "line_group" in args:
        trace_patch["mode"] = "lines" + ("+markers+text" if args["text"] else "")
    elif constructor != go.Splom and (
        "symbol" in args or constructor == go.Scattermapbox
    ):
        trace_patch["mode"] = "markers" + ("+text" if args["text"] else "")

    if "line_shape" in args:
        trace_patch["line"] = dict(shape=args["line_shape"])

    if "geojson" in args:
        trace_patch["featureidkey"] = args["featureidkey"]
        trace_patch["geojson"] = (
            args["geojson"]
            if not hasattr(args["geojson"], "__geo_interface__")  # for geopandas
            else args["geojson"].__geo_interface__
        )

    # Compute marginal attribute: copy to appropriate marginal_*
    if "marginal" in args:
        position = "marginal_x" if args["orientation"] == "v" else "marginal_y"
        other_position = "marginal_x" if args["orientation"] == "h" else "marginal_y"
        args[position] = args["marginal"]
        args[other_position] = None

    # If both marginals and faceting are specified, faceting wins
    if args.get("facet_col") is not None and args.get("marginal_y") is not None:
        args["marginal_y"] = None

    if args.get("facet_row") is not None and args.get("marginal_x") is not None:
        args["marginal_x"] = None

    # facet_col_wrap only works if no marginals or row faceting is used
    if (
        args.get("marginal_x") is not None
        or args.get("marginal_y") is not None
        or args.get("facet_row") is not None
    ):
        args["facet_col_wrap"] = 0

    # Compute applicable grouping attributes
    for k in group_attrables:
        if k in args:
            grouped_attrs.append(k)

    # Create grouped mappings
    grouped_mappings = [make_mapping(args, a) for a in grouped_attrs]

    # Create trace specs
    trace_specs = make_trace_spec(args, constructor, attrs, trace_patch)
    return trace_specs, grouped_mappings, sizeref, show_colorbar
def get_orderings(args, grouper, grouped):
    """
    Work out category orderings and the order in which groups are drawn.

    Returns ``(orders, sorted_group_names)``:

    * ``orders`` maps a column name (e.g. the column behind "x" or "color") to
      its list of categories. It starts from a copy of the user's
      ``category_orders``, values absent from the data included. For every
      grouping column, the values found in ``args["data_frame"]`` are appended
      in order of appearance, without duplicates.
    * ``sorted_group_names`` lists every group key of ``grouped`` as a tuple
      with one element per entry of ``grouper``. The list is ordered by the
      positions in ``orders``, with the first grouping column dominant; a
      value missing from its ordering would sort first.
    """
    orders = args["category_orders"].copy() if "category_orders" in args else {}
    data = args["data_frame"]
    for col in grouper:
        if col == one_group:
            continue
        seen = list(data[col].unique())
        if col in orders:
            # user-specified order first, then any extra values, de-duplicated
            orders[col] = list(OrderedDict.fromkeys(list(orders[col]) + seen))
        else:
            orders[col] = seen

    single_key = len(grouper) == 1
    sorted_group_names = [(key,) if single_key else key for key in grouped.groups]

    # Sort by the last grouping column first; list.sort is stable, so the end
    # result is lexicographic with the first grouping column dominant.
    for pos, col in reversed(list(enumerate(grouper))):
        if col == one_group:
            continue
        order = orders[col]
        sorted_group_names.sort(
            key=lambda group, p=pos, o=order: o.index(group[p])
            if group[p] in o
            else -1
        )
    return orders, sorted_group_names
def make_figure(args, constructor, trace_patch=None, layout_patch=None):
    """Build the figure for a px function call.

    Parameters
    ----------
    args : dict
        Arguments of the px function. They are normalized and rewritten in
        place by the helpers called below.
    constructor : graph_objects trace class, or the string "timeline"
        Trace type to create. "timeline" is rendered with go.Bar after
        `process_dataframe_timeline` has rewritten the data.
    trace_patch, layout_patch : dict, optional
        Extra trace/layout properties supplied by the calling px function. The
        layout patch is extended here before being applied.

    Returns
    -------
    The figure created by `init_figure`, with traces, layout and (when an
    animation produces more than one frame) frames attached.
    """
    trace_patch = trace_patch or {}
    layout_patch = layout_patch or {}
    apply_default_cascade(args)

    # Normalize all inputs into a single data frame whose columns args refer to
    args = build_dataframe(args, constructor)
    if constructor in [go.Treemap, go.Sunburst, go.Icicle] and args["path"] is not None:
        args = process_dataframe_hierarchy(args)
    if constructor == "timeline":
        constructor = go.Bar
        args = process_dataframe_timeline(args)

    trace_specs, grouped_mappings, sizeref, show_colorbar = infer_config(
        args, constructor, trace_patch, layout_patch
    )
    # One grouping key per mapping; the one_group sentinel stands in when a
    # mapping has no grouping column, so groupby always gets a non-empty list
    grouper = [x.grouper or one_group for x in grouped_mappings] or [one_group]
    grouped = args["data_frame"].groupby(grouper, sort=False)

    orders, sorted_group_names = get_orderings(args, grouper, grouped)

    # Fill each mapping's val_map (e.g. category -> color / facet index) in
    # category order, and collect facet labels and the facet grid size
    col_labels = []
    row_labels = []
    nrows = ncols = 1
    for m in grouped_mappings:
        if m.grouper not in orders:
            m.val_map[""] = m.sequence[0]
        else:
            sorted_values = orders[m.grouper]
            if m.facet == "col":
                prefix = get_label(args, args["facet_col"]) + "="
                col_labels = [prefix + str(s) for s in sorted_values]
                ncols = len(col_labels)
            if m.facet == "row":
                prefix = get_label(args, args["facet_row"]) + "="
                row_labels = [prefix + str(s) for s in sorted_values]
                nrows = len(row_labels)
            for val in sorted_values:
                if val not in m.val_map:  # always False if it's an IdentityMap
                    m.val_map[val] = m.sequence[len(m.val_map) % len(m.sequence)]

    subplot_type = _subplot_type_for_trace_type(constructor().type)

    # trace names already shown in the legend, per animation frame
    trace_names_by_frame = {}
    frames = OrderedDict()
    trendline_rows = []
    trace_name_labels = None
    facet_col_wrap = args.get("facet_col_wrap", 0)
    for group_name in sorted_group_names:
        group = grouped.get_group(group_name if len(group_name) > 1 else group_name[0])
        mapping_labels = OrderedDict()
        trace_name_labels = OrderedDict()
        frame_name = ""
        for col, val, m in zip(grouper, group_name, grouped_mappings):
            if col != one_group:
                key = get_label(args, col)
                if not isinstance(m.val_map, IdentityMap):
                    mapping_labels[key] = str(val)
                    if m.show_in_trace_name:
                        trace_name_labels[key] = str(val)
                if m.variable == "animation_frame":
                    frame_name = val
        trace_name = ", ".join(trace_name_labels.values())
        if frame_name not in trace_names_by_frame:
            trace_names_by_frame[frame_name] = set()
        trace_names = trace_names_by_frame[frame_name]

        for trace_spec in trace_specs:
            # Create the trace
            trace = trace_spec.constructor(name=trace_name)
            if trace_spec.constructor not in [
                go.Parcats,
                go.Parcoords,
                go.Choropleth,
                go.Choroplethmapbox,
                go.Densitymapbox,
                go.Histogram2d,
                go.Sunburst,
                go.Treemap,
                go.Icicle,
            ]:
                # only the first trace with a given name gets a legend entry
                trace.update(
                    legendgroup=trace_name,
                    showlegend=(trace_name != "" and trace_name not in trace_names),
                )
            if trace_spec.constructor in [go.Bar, go.Violin, go.Box, go.Histogram]:
                trace.update(alignmentgroup=True, offsetgroup=trace_name)
            trace_names.add(trace_name)

            # Init subplot row/col
            trace._subplot_row = 1
            trace._subplot_col = 1

            for i, m in enumerate(grouped_mappings):
                val = group_name[i]
                try:
                    m.updater(trace, m.val_map[val])  # covers most cases
                except ValueError:
                    # this catches some odd cases like marginals
                    if (
                        trace_spec != trace_specs[0]
                        and (
                            trace_spec.constructor in [go.Violin, go.Box]
                            and m.variable in ["symbol", "pattern"]
                        )
                        or (
                            trace_spec.constructor in [go.Histogram]
                            and m.variable in ["symbol"]
                        )
                    ):
                        pass
                    elif (
                        trace_spec != trace_specs[0]
                        and trace_spec.constructor in [go.Histogram]
                        and m.variable == "color"
                    ):
                        trace.update(marker=dict(color=m.val_map[val]))
                    elif (
                        trace_spec.constructor in [go.Choropleth, go.Choroplethmapbox]
                        and m.variable == "color"
                    ):
                        # discrete color on choropleths: a constant z with a
                        # single-color colorscale
                        trace.update(
                            z=[1] * len(group),
                            colorscale=[m.val_map[val]] * 2,
                            showscale=False,
                            showlegend=True,
                        )
                    else:
                        raise

                # Find row for trace, handling facet_row and marginal_x
                if m.facet == "row":
                    row = m.val_map[val]
                else:
                    if (
                        args.get("marginal_x") is not None  # there is a marginal
                        and trace_spec.marginal != "x"  # and we're not it
                    ):
                        row = 2
                    else:
                        row = 1

                # Find col for trace, handling facet_col and marginal_y
                if m.facet == "col":
                    col = m.val_map[val]
                    if facet_col_wrap:  # assumes no facet_row, no marginals
                        row = 1 + ((col - 1) // facet_col_wrap)
                        col = 1 + ((col - 1) % facet_col_wrap)
                else:
                    if trace_spec.marginal == "y":
                        col = 2
                    else:
                        col = 1

                if row > 1:
                    trace._subplot_row = row

                if col > 1:
                    trace._subplot_col = col
            if (
                trace_specs[0].constructor == go.Histogram2dContour
                and trace_spec.constructor == go.Box
                and trace.line.color
            ):
                trace.update(marker=dict(color=trace.line.color))

            patch, fit_results = make_trace_kwargs(
                args, trace_spec, group, mapping_labels.copy(), sizeref
            )
            trace.update(patch)
            if fit_results is not None:
                trendline_rows.append(mapping_labels.copy())
                trendline_rows[-1]["px_fit_results"] = fit_results
            if frame_name not in frames:
                frames[frame_name] = dict(data=[], name=frame_name)
            frames[frame_name]["data"].append(trace)
    frame_list = [f for f in frames.values()]
    if len(frame_list) > 1:
        # order animation frames by the animation_frame category order
        frame_list = sorted(
            frame_list, key=lambda f: orders[args["animation_frame"]].index(f["name"])
        )

    if show_colorbar:
        colorvar = "z" if constructor in [go.Histogram2d, go.Densitymapbox] else "color"
        range_color = args["range_color"] or [None, None]

        colorscale_validator = ColorscaleValidator("colorscale", "make_figure")
        layout_patch["coloraxis1"] = dict(
            colorscale=colorscale_validator.validate_coerce(
                args["color_continuous_scale"]
            ),
            cmid=args["color_continuous_midpoint"],
            cmin=range_color[0],
            cmax=range_color[1],
            colorbar=dict(
                title_text=get_decorated_label(args, args[colorvar], colorvar)
            ),
        )
    for v in ["height", "width"]:
        if args[v]:
            layout_patch[v] = args[v]
    layout_patch["legend"] = dict(tracegroupgap=0)
    if trace_name_labels:
        # legend title from the labels of the last group processed above
        layout_patch["legend"]["title_text"] = ", ".join(trace_name_labels)
    if args["title"]:
        layout_patch["title_text"] = args["title"]
    elif args["template"].layout.margin.t is None:
        layout_patch["margin"] = {"t": 60}
    if (
        "size" in args
        and args["size"]
        and args["template"].layout.legend.itemsizing is None
    ):
        layout_patch["legend"]["itemsizing"] = "constant"

    if facet_col_wrap:
        nrows = math.ceil(ncols / facet_col_wrap)
        ncols = min(ncols, facet_col_wrap)

    # marginal plots take one extra row / column of the subplot grid
    if args.get("marginal_x") is not None:
        nrows += 1

    if args.get("marginal_y") is not None:
        ncols += 1

    fig = init_figure(
        args, subplot_type, frame_list, nrows, ncols, col_labels, row_labels
    )

    # Position traces in subplots
    for frame in frame_list:
        for trace in frame["data"]:
            if isinstance(trace, go.Splom):
                # Special case that is not compatible with make_subplots
                continue

            # rows are flipped when mapping px's row number onto the grid
            _set_trace_grid_reference(
                trace,
                fig.layout,
                fig._grid_ref,
                nrows - trace._subplot_row + 1,
                trace._subplot_col,
            )

    # Add traces, layout and frames to figure
    fig.add_traces(frame_list[0]["data"] if len(frame_list) > 0 else [])
    fig.update_layout(layout_patch)
    if "template" in args and args["template"] is not None:
        fig.update_layout(template=args["template"], overwrite=True)
    fig.frames = frame_list if len(frames) > 1 else []

    fig._px_trendlines = pd.DataFrame(trendline_rows)

    configure_axes(args, constructor, fig, orders)
    configure_animation_controls(args, constructor, fig)
    return fig
def init_figure(args, subplot_type, frame_list, nrows, ncols, col_labels, row_labels):
    """Create the empty subplot grid that the px traces are later placed into.

    Parameters
    ----------
    args : dict
        The px argument dict. Keys read here: ``facet_col_wrap``,
        ``marginal_x``, ``marginal_y``, ``color``, ``facet_row_spacing`` and
        ``facet_col_spacing``.
    subplot_type : str or None
        Subplot type used for every cell (``"xy"`` enables marginal sizing);
        a falsy value falls back to ``"domain"``.
    frame_list : list
        Not used by this function; kept for call-site compatibility.
    nrows, ncols : int
        Grid dimensions, including any extra row/column for marginal plots.
    col_labels, row_labels : list
        Facet labels. With ``facet_col_wrap`` the column labels become
        per-cell subplot titles; otherwise they become column/row titles.
        Neither list is modified.

    Returns
    -------
    plotly.graph_objs.Figure
        Figure from ``make_subplots`` with all x and y axes shared and
        ``start_cell="bottom-left"``; every layout annotation has its
        explicit font cleared.

    Raises
    ------
    ValueError
        Re-raised from ``make_subplots``. When the message is about
        horizontal/vertical spacing, a hint naming ``facet_col_spacing`` /
        ``facet_row_spacing`` is appended first.
    """
    facet_col_wrap = args.get("facet_col_wrap", 0)

    # One independent spec dict per cell.
    specs = [
        [dict(type=subplot_type or "domain") for _ in range(ncols)]
        for _ in range(nrows)
    ]

    # Defaults: uniform row heights / column widths, and facet spacing taken
    # from args with a larger default row gap when facet columns wrap.
    row_heights = [1.0] * nrows
    column_widths = [1.0] * ncols
    vertical_spacing = args.get("facet_row_spacing") or (
        0.07 if facet_col_wrap else 0.03
    )
    horizontal_spacing = args.get("facet_col_spacing") or 0.02

    # Cartesian grids with marginals: the last row_heights entry (marginal_x)
    # and/or last column_widths entry (marginal_y) gets the small remaining
    # share, and the spacing on that axis is tightened, ignoring the
    # facet_*_spacing arguments.
    if subplot_type == "xy":
        if args.get("marginal_x") is not None:
            main_size = _marginal_main_size(args, "marginal_x")
            row_heights = [main_size] * (nrows - 1) + [1 - main_size]
            vertical_spacing = 0.01
        if args.get("marginal_y") is not None:
            main_size = _marginal_main_size(args, "marginal_y")
            column_widths = [main_size] * (ncols - 1) + [1 - main_size]
            horizontal_spacing = 0.005
    # Other subplot types ('scene', 'geo', 'polar', 'ternary', 'mapbox',
    # 'domain', None) keep the defaults above; spacing could be customized
    # per type once faceting is enabled for all plot types.

    subplot_labels = (
        _wrapped_subplot_labels(col_labels, nrows, ncols) if facet_col_wrap else []
    )

    def _spacing_error_translator(e, direction, facet_arg):
        """Re-raise ``e`` with a px-specific hint if it is a spacing error.

        Translates the spacing errors thrown by the underlying make_subplots
        routine into one that describes an argument adjustable through px.
        Returns without raising when ``e`` is not about ``direction`` spacing.
        """
        message = e.args[0] if e.args else None
        if isinstance(message, str) and ("%s spacing" % (direction,)) in message:
            e.args = (
                message
                + "\nUse the {facet_arg} argument to adjust this spacing.".format(
                    facet_arg=facet_arg
                ),
            )
            raise e

    # Create figure with subplots
    try:
        fig = make_subplots(
            rows=nrows,
            cols=ncols,
            specs=specs,
            shared_xaxes="all",
            shared_yaxes="all",
            row_titles=[] if facet_col_wrap else list(reversed(row_labels)),
            column_titles=[] if facet_col_wrap else col_labels,
            subplot_titles=subplot_labels,
            horizontal_spacing=horizontal_spacing,
            vertical_spacing=vertical_spacing,
            row_heights=row_heights,
            column_widths=column_widths,
            start_cell="bottom-left",
        )
    except ValueError as e:
        _spacing_error_translator(e, "Horizontal", "facet_col_spacing")
        _spacing_error_translator(e, "Vertical", "facet_row_spacing")
        # Any other ValueError propagates unchanged instead of being swallowed
        # (which would leave `fig` unbound below).
        raise

    # Remove explicit font size of row/col titles so template can take over
    for annot in fig.layout.annotations:
        annot.update(font=None)

    return fig


def _marginal_main_size(args, marginal_key):
    """Return the share of the grid given to the main plot beside a marginal.

    ``args[marginal_key]`` is the marginal kind (e.g. ``"histogram"``); the
    remaining ``1 - share`` goes to the marginal row/column.
    """
    # Histogram marginals, or any figure with a truthy ``color`` argument,
    # get the smaller main share (i.e. a larger marginal).
    if args[marginal_key] == "histogram" or args.get("color"):
        return 0.74
    return 0.84


def _wrapped_subplot_labels(col_labels, nrows, ncols):
    """Return one subplot title per cell for a wrapped (facet_col_wrap) grid.

    ``col_labels`` is padded with ``None`` up to ``nrows * ncols`` (labels
    beyond that are ignored) without modifying the caller's list. Entries are
    ordered row-major with the row order reversed relative to ``col_labels``
    -- presumably so that, with ``start_cell="bottom-left"``, the first labels
    appear on the top row.
    """
    n_cells = nrows * ncols
    padded = list(col_labels) + [None] * (n_cells - len(col_labels))
    return [
        padded[(nrows - 1 - i) * ncols + j]
        for i in range(nrows)
        for j in range(ncols)
    ]