Skip to content

Commit 83faec8

Browse files
committed
Support trace-specific color sequences in Plotly Express via templates
- Modify apply_default_cascade to check template.data.<trace_type> for marker.color or line.color - Fallback to template.layout.colorway if trace-specific colors not found - Add comprehensive tests for trace-specific color sequences - Handle timeline special case (maps to bar trace type) - Follow existing patterns for symbol_sequence and line_dash_sequence Fixes #5416
1 parent 965b3bd commit 83faec8

File tree

2 files changed

+94
-5
lines changed

2 files changed

+94
-5
lines changed

plotly/express/_core.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1037,28 +1037,36 @@ def apply_default_cascade(args, constructor=None):
10371037
if args["color_continuous_scale"] is None:
10381038
args["color_continuous_scale"] = sequential.Viridis
10391039

1040+
# if color_discrete_sequence not set explicitly or in px.defaults,
1041+
# see if we can defer to template. Try trace-specific colors first,
1042+
# then layout.colorway, then set reasonable defaults
10401043
if "color_discrete_sequence" in args:
10411044
if args["color_discrete_sequence"] is None and constructor is not None:
10421045
if constructor == "timeline":
10431046
trace_type = "bar"
10441047
else:
10451048
trace_type = constructor().type
10461049
if trace_data_list := getattr(args["template"].data, trace_type, None):
1047-
collected_colors = [
1050+
# try marker.color first
1051+
args["color_discrete_sequence"] = [
10481052
trace_data.marker.color
10491053
for trace_data in trace_data_list
10501054
if hasattr(trace_data, "marker")
10511055
]
1052-
if not collected_colors:
1053-
collected_colors = [
1056+
# fallback to line.color if marker.color not available
1057+
if not args["color_discrete_sequence"] or not any(args["color_discrete_sequence"]):
1058+
args["color_discrete_sequence"] = [
10541059
trace_data.line.color
10551060
for trace_data in trace_data_list
10561061
if hasattr(trace_data, "line")
10571062
]
1058-
if collected_colors:
1059-
args["color_discrete_sequence"] = collected_colors
1063+
# if no trace-specific colors found, reset to None to allow fallback
1064+
if not args["color_discrete_sequence"] or not any(args["color_discrete_sequence"]):
1065+
args["color_discrete_sequence"] = None
1066+
# fallback to layout.colorway if trace-specific colors not available
10601067
if args["color_discrete_sequence"] is None and args["template"].layout.colorway:
10611068
args["color_discrete_sequence"] = args["template"].layout.colorway
1069+
# final fallback to default qualitative palette
10621070
if args["color_discrete_sequence"] is None:
10631071
args["color_discrete_sequence"] = qualitative.D3
10641072

tests/test_optional/test_px/test_px.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,87 @@ def test_px_templates(backend):
226226
pio.templates.default = "plotly"
227227

228228

229+
def test_px_templates_trace_specific_colors(backend):
230+
import plotly.graph_objects as go
231+
232+
tips = px.data.tips(return_type=backend)
233+
234+
# read trace-specific colors from template.data.histogram
235+
histogram_template = go.layout.Template()
236+
histogram_template.data.histogram = [
237+
go.Histogram(marker=dict(color="orange")),
238+
go.Histogram(marker=dict(color="purple")),
239+
]
240+
fig = px.histogram(tips, x="total_bill", color="sex", template=histogram_template)
241+
assert fig.data[0].marker.color == "orange"
242+
assert fig.data[1].marker.color == "purple"
243+
244+
# scatter uses template.data.scatter colors, not histogram colors
245+
scatter_template = go.layout.Template()
246+
scatter_template.data.scatter = [
247+
go.Scatter(marker=dict(color="cyan")),
248+
go.Scatter(marker=dict(color="magenta")),
249+
]
250+
scatter_template.data.histogram = [
251+
go.Histogram(marker=dict(color="orange")),
252+
]
253+
fig = px.scatter(tips, x="total_bill", y="tip", color="sex", template=scatter_template)
254+
assert fig.data[0].marker.color == "cyan"
255+
assert fig.data[1].marker.color == "magenta"
256+
257+
# histogram still uses histogram colors even when scatter colors exist
258+
fig = px.histogram(tips, x="total_bill", color="sex", template=scatter_template)
259+
assert fig.data[0].marker.color == "orange"
260+
261+
# fallback to layout.colorway when trace-specific colors don't exist
262+
fig = px.histogram(
263+
tips, x="total_bill", color="sex", template=dict(layout_colorway=["yellow", "green"])
264+
)
265+
assert fig.data[0].marker.color == "yellow"
266+
assert fig.data[1].marker.color == "green"
267+
268+
# timeline special case (maps to bar)
269+
timeline_template = go.layout.Template()
270+
timeline_template.data.bar = [
271+
go.Bar(marker=dict(color="red")),
272+
go.Bar(marker=dict(color="blue")),
273+
]
274+
timeline_data = {
275+
"Task": ["Job A", "Job B"],
276+
"Start": ["2009-01-01", "2009-03-05"],
277+
"Finish": ["2009-02-28", "2009-04-15"],
278+
"Resource": ["Alex", "Max"],
279+
}
280+
# Use same backend as tips for consistency
281+
df_timeline = px.data.tips(return_type=backend)
282+
df_timeline = nw.from_native(df_timeline).with_columns(
283+
nw.lit("Job A").alias("Task"),
284+
nw.lit("2009-01-01").alias("Start"),
285+
nw.lit("2009-02-28").alias("Finish"),
286+
nw.lit("Alex").alias("Resource"),
287+
).head(1).to_native()
288+
# Add second row
289+
df_timeline2 = nw.from_native(df_timeline).with_columns(
290+
nw.lit("Job B").alias("Task"),
291+
nw.lit("2009-03-05").alias("Start"),
292+
nw.lit("2009-04-15").alias("Finish"),
293+
nw.lit("Max").alias("Resource"),
294+
).head(1).to_native()
295+
# Combine - actually, this is getting too complex. Let me just use a simpler approach
296+
import pandas as pd
297+
df_timeline = pd.DataFrame(timeline_data)
298+
fig = px.timeline(
299+
df_timeline,
300+
x_start="Start",
301+
x_end="Finish",
302+
y="Task",
303+
color="Resource",
304+
template=timeline_template,
305+
)
306+
assert fig.data[0].marker.color == "red"
307+
assert fig.data[1].marker.color == "blue"
308+
309+
229310
def test_px_defaults():
230311
px.defaults.labels = dict(x="hey x")
231312
px.defaults.category_orders = dict(color=["b", "a"])

0 commit comments

Comments
 (0)