Skip to content

Commit ff02d10

Browse files
Merge pull request #11 from zachessesjohnson/copilot/fix-toml-syntax-error
Fix CI failures: TOML syntax error, asm_spread typo, missing plotting module
2 parents 53f0934 + 4a55275 commit ff02d10

File tree

5 files changed

+407
-2
lines changed

5 files changed

+407
-2
lines changed

.github/workflows/pytest.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ jobs:
2525
python -m pip install -U pip
2626
pip install -r requirements.txt
2727
pip install -e .
28+
pip install -e sovereign_debt_py/
2829
2930
- name: Run tests
3031
run: |

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,4 +22,3 @@ dev = ["pytest"]
2222
[tool.setuptools.packages.find]
2323
where = ["."]
2424
include = ["sovereign_debt_xl*"]
25-
include = ["sovereign_debt_xl*", "sovereign_debt_py*"]
Lines changed: 349 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,349 @@
1+
"""Pure-Python plotting helpers for sovereign debt analysis.
2+
3+
All functions return ``(fig, ax)`` tuples so callers can further customise
4+
the chart before saving or displaying it. Use :func:`fig_to_png_bytes` to
5+
serialise a figure to PNG bytes.
6+
"""
7+
from __future__ import annotations
8+
9+
import datetime
10+
import io
11+
from typing import Any
12+
13+
import matplotlib
14+
15+
matplotlib.use("Agg") # headless – must come before pyplot import
16+
import matplotlib.pyplot as plt
17+
import matplotlib.dates as mdates
18+
import numpy as np
19+
20+
from .core import coerce_dates, to_1d_array, validate_same_length
21+
22+
__all__ = [
23+
"plot_yield_curve",
24+
"plot_timeseries",
25+
"plot_rolling_average",
26+
"plot_spread",
27+
"plot_fan_chart",
28+
"fig_to_png_bytes",
29+
]
30+
31+
_VALID_STYLES = ("line", "markers", "line+markers")
32+
33+
34+
# ---------------------------------------------------------------------------
35+
# plot_yield_curve
36+
# ---------------------------------------------------------------------------
37+
38+
def plot_yield_curve(
39+
tenors: Any,
40+
yields: Any,
41+
title: str | None = None,
42+
style: str = "line",
43+
fig: Any = None,
44+
ax: Any = None,
45+
) -> tuple[Any, Any]:
46+
"""Plot a yield curve (tenors vs yields) and return ``(fig, ax)``.
47+
48+
Parameters
49+
----------
50+
tenors:
51+
Sequence of numeric tenor values (years).
52+
yields:
53+
Sequence of yield values matching *tenors* in length.
54+
title:
55+
Optional chart title.
56+
style:
57+
One of ``"line"``, ``"markers"``, or ``"line+markers"``.
58+
fig, ax:
59+
Optional existing Matplotlib figure and axes to plot into.
60+
61+
Raises
62+
------
63+
ValueError
64+
If *tenors* and *yields* have different lengths, or *style* is invalid.
65+
"""
66+
t_arr = to_1d_array(tenors)
67+
y_arr = to_1d_array(yields)
68+
validate_same_length(t_arr, y_arr)
69+
if style not in _VALID_STYLES:
70+
raise ValueError(
71+
f"Invalid style {style!r}; must be one of {_VALID_STYLES}"
72+
)
73+
74+
if fig is None or ax is None:
75+
fig, ax = plt.subplots()
76+
77+
use_marker = style in ("markers", "line+markers")
78+
use_line = style in ("line", "line+markers")
79+
ax.plot(
80+
t_arr,
81+
y_arr,
82+
marker="o" if use_marker else None,
83+
linestyle="-" if use_line else "none",
84+
linewidth=2,
85+
)
86+
if title:
87+
ax.set_title(title)
88+
ax.set_xlabel("Tenor")
89+
ax.set_ylabel("Yield")
90+
ax.grid(True, linestyle="--", alpha=0.5)
91+
return fig, ax
92+
93+
94+
# ---------------------------------------------------------------------------
95+
# plot_timeseries
96+
# ---------------------------------------------------------------------------
97+
98+
def plot_timeseries(
99+
dates: Any,
100+
values: Any,
101+
title: str | None = None,
102+
) -> tuple[Any, Any]:
103+
"""Plot a time-series (dates vs values) and return ``(fig, ax)``.
104+
105+
Parameters
106+
----------
107+
dates:
108+
Sequence of date-like values (:class:`datetime.date`,
109+
:class:`datetime.datetime`, or ISO-format strings).
110+
values:
111+
Sequence of numeric values matching *dates* in length.
112+
title:
113+
Optional chart title.
114+
115+
Raises
116+
------
117+
ValueError
118+
If *dates* and *values* have different lengths.
119+
"""
120+
d_list = coerce_dates(dates)
121+
v_arr = to_1d_array(values)
122+
validate_same_length(np.array(d_list), v_arr)
123+
124+
fig, ax = plt.subplots()
125+
x_dt = [datetime.datetime(d.year, d.month, d.day) for d in d_list]
126+
ax.plot(x_dt, v_arr, linewidth=1.5)
127+
if title:
128+
ax.set_title(title)
129+
ax.xaxis.set_major_locator(mdates.AutoDateLocator())
130+
ax.xaxis.set_major_formatter(
131+
mdates.AutoDateFormatter(mdates.AutoDateLocator())
132+
)
133+
fig.autofmt_xdate(rotation=30)
134+
ax.grid(True, linestyle="--", alpha=0.5)
135+
return fig, ax
136+
137+
138+
# ---------------------------------------------------------------------------
139+
# plot_rolling_average
140+
# ---------------------------------------------------------------------------
141+
142+
def plot_rolling_average(
143+
dates: Any,
144+
values: Any,
145+
window: int,
146+
base_label: str = "Original",
147+
roll_label: str | None = None,
148+
) -> tuple[Any, Any]:
149+
"""Plot data with a rolling-average overlay and return ``(fig, ax)``.
150+
151+
Parameters
152+
----------
153+
dates:
154+
Sequence of date-like values.
155+
values:
156+
Sequence of numeric values matching *dates* in length.
157+
window:
158+
Rolling window size (number of periods). Must be ≥ 1 and ≤
159+
``len(values)``.
160+
base_label:
161+
Legend label for the raw series (default ``"Original"``).
162+
roll_label:
163+
Legend label for the rolling-average series. Defaults to
164+
``f"Rolling {window}"``.
165+
166+
Raises
167+
------
168+
ValueError
169+
If *dates* and *values* have different lengths, or *window* is larger
170+
than the data length.
171+
"""
172+
d_list = coerce_dates(dates)
173+
v_arr = to_1d_array(values)
174+
validate_same_length(np.array(d_list), v_arr)
175+
window = int(window)
176+
if window < 1:
177+
raise ValueError(f"window must be >= 1 (got {window})")
178+
if window > len(v_arr):
179+
raise ValueError(
180+
f"window ({window}) is larger than the data length ({len(v_arr)})"
181+
)
182+
if roll_label is None:
183+
roll_label = f"Rolling {window}"
184+
185+
rolling = np.full_like(v_arr, np.nan, dtype=float)
186+
for i in range(window - 1, len(v_arr)):
187+
rolling[i] = float(np.mean(v_arr[i - window + 1: i + 1]))
188+
189+
fig, ax = plt.subplots()
190+
x_dt = [datetime.datetime(d.year, d.month, d.day) for d in d_list]
191+
ax.plot(x_dt, v_arr, color="lightsteelblue", linewidth=1.0, alpha=0.7, label=base_label)
192+
ax.plot(x_dt, rolling, color="steelblue", linewidth=2.0, label=roll_label)
193+
ax.xaxis.set_major_locator(mdates.AutoDateLocator())
194+
ax.xaxis.set_major_formatter(
195+
mdates.AutoDateFormatter(mdates.AutoDateLocator())
196+
)
197+
fig.autofmt_xdate(rotation=30)
198+
ax.legend(loc="best")
199+
ax.grid(True, linestyle="--", alpha=0.5)
200+
return fig, ax
201+
202+
203+
# ---------------------------------------------------------------------------
204+
# plot_spread
205+
# ---------------------------------------------------------------------------
206+
207+
def plot_spread(
208+
x: Any,
209+
series_a: Any,
210+
series_b: Any,
211+
label_a: str = "Series A",
212+
label_b: str = "Series B",
213+
title: str | None = None,
214+
) -> tuple[Any, Any]:
215+
"""Plot two series and their spread and return ``(fig, ax)``.
216+
217+
Parameters
218+
----------
219+
x:
220+
Sequence of x-axis values (dates or numerics).
221+
series_a, series_b:
222+
Sequences of numeric values, each matching *x* in length.
223+
label_a, label_b:
224+
Legend labels.
225+
title:
226+
Optional chart title.
227+
228+
Raises
229+
------
230+
ValueError
231+
If any of the three sequences have different lengths.
232+
"""
233+
a_arr = to_1d_array(series_a)
234+
b_arr = to_1d_array(series_b)
235+
236+
# x may be dates or numerics
237+
try:
238+
x_vals: Any = coerce_dates(x)
239+
x_plot = [datetime.datetime(d.year, d.month, d.day) for d in x_vals]
240+
use_dates = True
241+
except (ValueError, TypeError):
242+
x_plot = list(x)
243+
use_dates = False
244+
245+
validate_same_length(np.array(x_plot), a_arr)
246+
validate_same_length(a_arr, b_arr)
247+
248+
fig, ax = plt.subplots()
249+
ax.plot(x_plot, a_arr, label=label_a, linewidth=1.5)
250+
ax.plot(x_plot, b_arr, label=label_b, linewidth=1.5)
251+
if use_dates:
252+
ax.xaxis.set_major_locator(mdates.AutoDateLocator())
253+
ax.xaxis.set_major_formatter(
254+
mdates.AutoDateFormatter(mdates.AutoDateLocator())
255+
)
256+
fig.autofmt_xdate(rotation=30)
257+
if title:
258+
ax.set_title(title)
259+
ax.legend(loc="best")
260+
ax.grid(True, linestyle="--", alpha=0.5)
261+
return fig, ax
262+
263+
264+
# ---------------------------------------------------------------------------
265+
# plot_fan_chart
266+
# ---------------------------------------------------------------------------
267+
268+
def plot_fan_chart(
269+
x: Any,
270+
p50: Any,
271+
bands: dict[tuple[float, float], tuple[Any, Any]],
272+
title: str | None = None,
273+
) -> tuple[Any, Any]:
274+
"""Plot a fan chart with confidence bands and return ``(fig, ax)``.
275+
276+
Parameters
277+
----------
278+
x:
279+
Sequence of x-axis values (e.g. years).
280+
p50:
281+
Sequence of median values matching *x* in length.
282+
bands:
283+
Mapping of ``(low_prob, high_prob)`` → ``(lower_series, upper_series)``
284+
where each series matches *x* in length.
285+
title:
286+
Optional chart title.
287+
288+
Raises
289+
------
290+
ValueError
291+
If *x* and *p50* (or any band series) have different lengths.
292+
"""
293+
x_arr = to_1d_array(x)
294+
p50_arr = to_1d_array(p50)
295+
validate_same_length(x_arr, p50_arr)
296+
297+
fig, ax = plt.subplots()
298+
for (lo_prob, hi_prob), (lower, upper) in bands.items():
299+
lo_arr = to_1d_array(lower)
300+
hi_arr = to_1d_array(upper)
301+
validate_same_length(x_arr, lo_arr)
302+
validate_same_length(x_arr, hi_arr)
303+
label = f"{int(lo_prob * 100)}{int(hi_prob * 100)}%"
304+
ax.fill_between(x_arr, lo_arr, hi_arr, alpha=0.3, label=label)
305+
ax.plot(x_arr, p50_arr, linewidth=2, label="Median", color="steelblue")
306+
if title:
307+
ax.set_title(title)
308+
ax.legend(loc="best")
309+
ax.grid(True, linestyle="--", alpha=0.5)
310+
return fig, ax
311+
312+
313+
# ---------------------------------------------------------------------------
314+
# fig_to_png_bytes
315+
# ---------------------------------------------------------------------------
316+
317+
def fig_to_png_bytes(
318+
fig: Any,
319+
width_px: int | None = None,
320+
height_px: int | None = None,
321+
dpi: int = 100,
322+
close: bool = False,
323+
) -> bytes:
324+
"""Render *fig* to PNG bytes.
325+
326+
Parameters
327+
----------
328+
fig:
329+
Matplotlib figure to render.
330+
width_px, height_px:
331+
Optional output dimensions in pixels. When provided the figure size
332+
is adjusted before rendering.
333+
dpi:
334+
Dots per inch for the output PNG (default 100).
335+
close:
336+
If ``True``, close *fig* after rendering to free memory.
337+
338+
Returns
339+
-------
340+
bytes
341+
PNG-encoded image bytes.
342+
"""
343+
if width_px is not None and height_px is not None:
344+
fig.set_size_inches(width_px / dpi, height_px / dpi)
345+
buf = io.BytesIO()
346+
fig.savefig(buf, format="png", dpi=dpi, bbox_inches="tight")
347+
if close:
348+
plt.close(fig)
349+
return buf.getvalue()

0 commit comments

Comments
 (0)