Skip to content

Commit 6a6f2b3

Browse files
authored
Merge pull request #62 from ttedeschi/subplot-styling
fix subplot styling
2 parents d0c65ee + 81c629e commit 6a6f2b3

File tree

2 files changed

+111
-30
lines changed

2 files changed

+111
-30
lines changed

src/cmsstyle/cmsstyle.py

Lines changed: 90 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1961,8 +1961,14 @@ def plot_common_legend(
19611961
ydown: int | None = None,
19621962
yup: int | None = None,
19631963
title: str = "CMS",
1964-
textalign: int = 11,
1964+
titleFont: int = 62,
1965+
titleSize: float = 50 * 0.75 / 0.6,
1966+
subtitle: str = "Preliminary",
1967+
subtitleFont: str = 52,
1968+
textalign: int = 13,
1969+
ipos: int = 0,
19651970
):
1971+
pad._pad.cd()
19661972
horizontal_margin = (
19671973
self._grid_metadata.pad_horizontal_margin / self._grid_metadata.ncolumns
19681974
)
@@ -1973,23 +1979,61 @@ def plot_common_legend(
19731979

19741980
leg = rt.TLegend(xleft, ydown, xright, yup)
19751981
leg.SetTextAlign(textalign)
1976-
leg.SetHeader(title)
1982+
19771983
leg.SetBorderSize(1)
1984+
leg.SetMargin(0.5)
19781985

19791986
# Have at most 4 items on the same row
19801987
ndrawables = len(args)
1981-
leg.SetNColumns(ndrawables if ndrawables < 5 else 4)
1988+
ncolumns = (ndrawables + 1) if (ndrawables + 1) < 6 else 5
1989+
leg.SetNColumns(ncolumns)
1990+
if ipos != 0:
1991+
n = 0
1992+
for arg in args:
1993+
if n % ncolumns == 0:
1994+
leg.AddEntry(0, " ", " ")
1995+
n += 1
1996+
leg.AddEntry(arg.obj, arg.name, arg.opt)
1997+
n += 1
1998+
else:
1999+
for arg in args:
2000+
leg.AddEntry(arg.obj, arg.name, arg.opt)
19822001

1983-
for arg in args:
1984-
leg.AddEntry(arg.obj, arg.name, arg.opt)
19852002
pad.plot(leg)
19862003

2004+
latex = rt.TLatex()
2005+
latex.SetNDC()
2006+
latex.SetTextFont(titleFont)
2007+
2008+
canvas_height = pad._pad.GetWh()
2009+
ymin = pad._pad.GetYlowNDC()
2010+
ymax = pad._pad.GetYlowNDC() + pad._pad.GetHNDC()
2011+
pad_ndc_height = ymax - ymin
2012+
pad_pixel_height = canvas_height * pad_ndc_height
2013+
titleSize = titleSize / pad_pixel_height
2014+
subtitleSize = titleSize * 0.76
2015+
2016+
latex.SetTextSize(titleSize)
2017+
latex.SetTextAlign(13)
2018+
if ipos != 0:
2019+
latex.DrawLatex(0.11, 0.60, title)
2020+
else:
2021+
latex.DrawLatex(0.10, 0.97, title)
2022+
latex.SetTextFont(subtitleFont)
2023+
latex.SetTextSize(subtitleSize)
2024+
if ipos != 0:
2025+
latex.DrawLatex(0.11, 0.30, subtitle)
2026+
else:
2027+
latex.DrawLatex(0.17, 0.94, subtitle)
2028+
2029+
19872030
def plot_text(
19882031
self,
19892032
pad: CMSPad,
19902033
text,
1991-
textsize=0.1,
1992-
textalign=11,
2034+
textsize=50,
2035+
textfont=42,
2036+
textalign=33,
19932037
xcoord: int | None = None,
19942038
ycoord: int | None = None,
19952039
):
@@ -2008,9 +2052,17 @@ def plot_text(
20082052
latex.SetTextAngle(0)
20092053
latex.SetTextColor(rt.kBlack)
20102054

2011-
latex.SetTextFont(42)
2055+
latex.SetTextFont(textfont)
20122056
latex.SetTextAlign(textalign)
2057+
2058+
canvas_height = pad._pad.GetWh()
2059+
ymin = pad._pad.GetYlowNDC()
2060+
ymax = pad._pad.GetYlowNDC() + pad._pad.GetHNDC()
2061+
pad_ndc_height = ymax - ymin
2062+
pad_pixel_height = canvas_height * pad_ndc_height
2063+
textsize = textsize / pad_pixel_height
20132064
latex.SetTextSize(textsize)
2065+
20142066
latex.DrawLatex(xcoord, ycoord, text)
20152067
latex.Draw()
20162068

@@ -2143,6 +2195,8 @@ def subplots(
21432195
shared_y_axis: bool = True,
21442196
canvas_width: int = 2000,
21452197
canvas_height: int = 2000,
2198+
axis_title_size: float = 50,
2199+
axis_label_size: float = 50 * 0.8
21462200
) -> CMSCanvasManager:
21472201
"""
21482202
Creates multiple pads in a canvas according to the input configuration, then
@@ -2159,6 +2213,8 @@ def subplots(
21592213
- shared_y_axis: whether the y axis of all columns should be shared
21602214
- canvas_width: total width of the canvas
21612215
- canvas_height: total height of the canvas
2216+
- axis_title_size: reference absolute size for axis titles
2217+
- axis_label_size: reference absolute size for axis labels
21622218
"""
21632219

21642220
top_pad = None
@@ -2205,11 +2261,9 @@ def subplots(
22052261
)
22062262
pad.SetBottomMargin(epsilon_height)
22072263
elif row_index == nrows - 1:
2208-
pad.SetTopMargin(epsilon_height)
2209-
pad.SetBottomMargin(
2210-
pad_vertical_margin * (1 / height_ratios[i // ncolumns])
2211-
- epsilon_height
2212-
)
2264+
margin = pad_vertical_margin * (1 / height_ratios[i // ncolumns]) / 2
2265+
pad.SetTopMargin(margin)
2266+
pad.SetBottomMargin(margin)
22132267
else:
22142268
pad.SetTopMargin(
22152269
pad_vertical_margin / 2 * (1 / height_ratios[i // ncolumns])
@@ -2269,22 +2323,37 @@ def subplots(
22692323
for frame, pad in zip(listofframes[-ncolumns:], listofpads[-ncolumns:]):
22702324
with _managed_tpad_context(canvas):
22712325
pad.cd()
2272-
frame.GetXaxis().SetLabelSize(0.3)
2326+
2327+
canvas_height = listofpads[i].GetWh()
2328+
ymin = listofpads[i].GetYlowNDC()
2329+
ymax = listofpads[i].GetYlowNDC() + listofpads[i].GetHNDC()
2330+
pad_ndc_height = ymax - ymin
2331+
pad_pixel_height = canvas_height * pad_ndc_height
2332+
labeltextsize = axis_label_size / pad_pixel_height
2333+
frame.GetXaxis().SetLabelSize(labeltextsize)
22732334
frame.GetXaxis().SetNdivisions(5, 5, 0, True)
22742335

22752336
if shared_y_axis:
22762337
for i in range(0, len(listofframes), ncolumns):
22772338
with _managed_tpad_context(canvas):
22782339
listofpads[i].cd()
2279-
listofframes[i].GetYaxis().SetLabelSize(
2280-
0.3 * (1 / height_ratios[i // ncolumns])
2281-
)
2340+
2341+
canvas_height = listofpads[i].GetWh()
2342+
ymin = listofpads[i].GetYlowNDC()
2343+
ymax = listofpads[i].GetYlowNDC() + listofpads[i].GetHNDC()
2344+
pad_ndc_height = ymax - ymin
2345+
pad_pixel_height = canvas_height * pad_ndc_height
2346+
labeltextsize = axis_label_size / pad_pixel_height
2347+
2348+
listofframes[i].GetYaxis().SetLabelSize(labeltextsize)
22822349
listofframes[i].GetYaxis().SetNdivisions(3, 5, 0, True)
2283-
listofframes[i].GetYaxis().SetTitleSize(
2284-
0.4 * (1 / height_ratios[i // ncolumns])
2285-
)
2350+
2351+
titletextsize = axis_title_size / pad_pixel_height
2352+
listofframes[i].GetYaxis().SetTitleSize(titletextsize)
2353+
2354+
22862355
listofframes[i].GetYaxis().SetTitleOffset(
2287-
2 * (height_ratios[i // ncolumns] / sum(height_ratios))
2356+
3 * (height_ratios[i // ncolumns] / sum(height_ratios))
22882357
)
22892358

22902359
return CMSCanvasManager(

tests/test_subplots.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -59,32 +59,44 @@ def test_subplots():
5959
"""Example of multiple plots in the same canvas, with shared common legend"""
6060
ncolumns = 2
6161
nrows = 6
62+
6263
cvm = cmsstyle.subplots(
6364
ncolumns=ncolumns,
6465
nrows=nrows,
65-
height_ratios=[3, 1] * (nrows // 2),
66+
height_ratios=[2, 1] * (nrows // 2),
6667
canvas_top_margin=0.1,
67-
canvas_bottom_margin=0.03)
68+
canvas_bottom_margin=0.03,
69+
axis_label_size = 40
70+
)
6871

6972
data, hs, h_err, ratio, yerr_root, ref_line, bkg, signal = _create_drawables()
7073

7174
cvm.plot_common_legend(
7275
cvm.top_pad,
76+
cmsstyle.LegendItem(data, "Uncertainty", "pe"),
77+
cmsstyle.LegendItem(bkg, "MC1", "f"),
78+
cmsstyle.LegendItem(signal, "MC2", "f"),
79+
cmsstyle.LegendItem(ratio, "Ratio", "pe"),
80+
cmsstyle.LegendItem(ratio, "Ratio", "pe"),
81+
cmsstyle.LegendItem(signal, "Testing", "f"),
7382
cmsstyle.LegendItem(data, "Data", "pe"),
7483
cmsstyle.LegendItem(bkg, "MC1", "f"),
7584
cmsstyle.LegendItem(signal, "MC2", "f"),
76-
cmsstyle.LegendItem(ratio, "Ratio", "pe"))
85+
cmsstyle.LegendItem(data, "Hello", "pe"),
86+
cmsstyle.LegendItem(ratio, "BigTitle", "pe"),
87+
textalign=12,
88+
ipos = 11
89+
)
7790
cvm.plot_text(
7891
cvm.top_pad,
7992
"Run 2, 138 fb^{#minus1}",
80-
textsize=0.3,
81-
textalign=33)
93+
)
8294
cvm.plot_text(
8395
cvm.bottom_pad,
8496
"m^{ll} (GeV)",
85-
textsize=1,
86-
textalign=33)
87-
cvm.ylabel("Test")
97+
textsize=50,
98+
)
99+
cvm.ylabel(labels={0:"Test0", 2:"", 4:"Test4", 6:"", 8:"Test8", 10:""})
88100

89101
row_index = -1
90102
for i, pad in enumerate(cvm.pads):
@@ -103,4 +115,4 @@ def test_subplots():
103115

104116

105117
if __name__ == "__main__":
106-
test_subplots()
118+
test_subplots()

0 commit comments

Comments
 (0)