Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions frontend/src/components/editor/Output.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -141,10 +141,17 @@ export const OutputRenderer: React.FC<{
case "image/bmp":
case "image/gif":
case "image/jpeg":
case "image/svg+xml":
invariant(
typeof data === "string",
`Expected string data for mime=${mimetype}. Got ${typeof data}`,
);
if (
mimetype === "image/svg+xml" &&
!data.startsWith("data:image/svg+xml;base64,")
) {
return renderHTML({ html: data, alwaysSanitizeHtml: true });
}
return (
<ImageOutput
className={channel}
Expand All @@ -154,12 +161,6 @@ export const OutputRenderer: React.FC<{
height={metadata?.height}
/>
);
case "image/svg+xml":
invariant(
typeof data === "string",
`Expected string data for mime=${mimetype}. Got ${typeof data}`,
);
return renderHTML({ html: data, alwaysSanitizeHtml: true });

case "video/mp4":
case "video/mpeg":
Expand Down
59 changes: 59 additions & 0 deletions frontend/src/components/editor/__tests__/Output.test.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,62 @@ describe("OutputRenderer renderFallback prop", () => {
).toBeInTheDocument();
});
});

describe("OutputRenderer image and SVG rendering", () => {
const plainSvgString =
'<svg><rect x="0" y="0" width="10" height="10"></rect></svg>';
const base64SvgDataUrl =
"";
const base64PngDataUrl =
"";

it("should render plain SVG string via renderHTML", () => {
const { container } = render(
<OutputRenderer
message={{
channel: "output",
data: plainSvgString,
mimetype: "image/svg+xml",
}}
/>,
);
const svgElement = container.querySelector("svg");
expect(svgElement).not.toBeNull();
const rectElement = svgElement!.querySelector("rect");
expect(rectElement).not.toBeNull();
const imgElement = container.querySelector("img");
expect(imgElement).toBeNull();
});

it("should render Base64 SVG data URL via ImageOutput", () => {
const { container } = render(
<OutputRenderer
message={{
channel: "output",
data: base64SvgDataUrl,
mimetype: "image/svg+xml",
}}
/>,
);
const imgElement = container.querySelector("img");
expect(imgElement).not.toBeNull();
expect(imgElement).toHaveAttribute("src", base64SvgDataUrl);
const svgElement = container.querySelector("svg");
expect(svgElement).toBeNull();
});

it("should render Base64 PNG data URL via ImageOutput", () => {
const { container } = render(
<OutputRenderer
message={{
channel: "output",
data: base64PngDataUrl,
mimetype: "image/png",
}}
/>,
);
const imgElement = container.querySelector("img");
expect(imgElement).not.toBeNull();
expect(imgElement).toHaveAttribute("src", base64PngDataUrl);
});
});
5 changes: 5 additions & 0 deletions marimo/_convert/ipynb/from_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from __future__ import annotations

import base64
import io
import json
import re
Expand Down Expand Up @@ -191,6 +192,10 @@ def _add_marimo_metadata(


def _maybe_extract_dataurl(data: Any) -> Any:
if isinstance(data, str) and data.startswith("data:image/svg+xml;base64,"):
# Decode SVG from base64 to plain text XML
payload = data[len("data:image/svg+xml;base64,") :]
return base64.b64decode(payload).decode()
if (
isinstance(data, str)
and data.startswith("data:")
Expand Down
13 changes: 11 additions & 2 deletions marimo/_output/mpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,20 @@ def _render_figure_mimebundle(
fig: Matplotlib figure canvas to render

Returns:
Tuple of (mimetype, json_data) where json_data is a mimebundle
containing the PNG data URL and display metadata
Tuple of (mimetype, data). If `matplotlib.rcParams["savefig.format"]` is 'svg',
mimetype is 'image/svg+xml' and data is the Base64-encoded SVG data URL.
Otherwise, mimetype is 'application/vnd.marimo+mimebundle' and data is a JSON string
representing a mimebundle containing the PNG data URL and display metadata.
"""
buf = io.BytesIO()

if plt.rcParams["savefig.format"] == "svg":
fig.figure.savefig(buf, format="svg", bbox_inches="tight") # type: ignore[attr-defined]
svg_bytes = buf.getvalue()
plot_bytes = base64.b64encode(svg_bytes)
data_url = build_data_url(mimetype="image/svg+xml", data=plot_bytes)
return "image/svg+xml", data_url

# Get current DPI and double it for retina display (like Jupyter)
original_dpi = fig.figure.dpi # type: ignore[attr-defined]
retina_dpi = original_dpi * 2
Expand Down
6 changes: 6 additions & 0 deletions tests/_convert/ipynb/test_from_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,11 @@ def __():
"",
"iVBORw0KGgoAAAANSUhEUgAAAAUA",
),
# SVG string from Base64 data URL
(
"",
"<svg></svg>",
),
# Non-data-URL string passes through
("hello world", "hello world"),
# Dict passes through
Expand All @@ -99,6 +104,7 @@ def __():
],
ids=[
"base64_data_url",
"svg_string_from_base64_data_url",
"regular_string",
"dict_passthrough",
"int_passthrough",
Expand Down
37 changes: 37 additions & 0 deletions tests/_output/formatters/test_matplotlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,3 +243,40 @@ async def test_matplotlib_backwards_compatibility(
assert mime_type == "application/vnd.marimo+mimebundle"
mimebundle = json.loads(data)
assert "image/png" in mimebundle


@pytest.mark.skipif(not HAS_MPL, reason="optional dependencies not installed")
async def test_matplotlib_svg_rendering(
executing_kernel: Kernel, exec_req: ExecReqProvider
) -> None:
"""Test that matplotlib figures are rendered in SVG format."""
from marimo._output.formatters.formatters import register_formatters

register_formatters(theme="light")

await executing_kernel.run(
[
exec_req.get(
"""
import matplotlib.pyplot as plt

fmt = plt.rcParams["savefig.format"]
plt.rcParams["savefig.format"] = "svg"

# Create a simple figure
fig, ax = plt.subplots(figsize=(4, 3))
ax.plot([1, 2, 3], [1, 2, 3])
result = fig._mime_()

plt.rcParams["savefig.format"] = fmt
"""
)
]
)

# Get the formatted result from kernel globals
mime_type, data = executing_kernel.globals["result"]

assert mime_type == "image/svg+xml"
assert isinstance(data, str)
assert data.startswith("")
Loading