Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 21 additions & 6 deletions databao/core/visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from abc import ABC, abstractmethod
from typing import Any

from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel, ConfigDict, Field

from databao.core.executor import ExecutionResult

Expand All @@ -24,9 +24,22 @@ class VisualisationResult(BaseModel):
plot: Any | None
code: str | None

visualizer: "Visualizer | None" = Field(exclude=True)
"""Reference to the Visualizer that produced this result. Not serializable."""

# Immutable model; allow arbitrary plot types (e.g., matplotlib objects)
model_config = ConfigDict(frozen=True, arbitrary_types_allowed=True)

def edit(self, request: str, *, stream: bool = False) -> "VisualisationResult":
"""Edit this visualization with a natural language request.

Syntactic sugar for the `Visualizer.edit` method.
"""
if self.visualizer is None:
# Forbid using `.edit` after deserialization
raise RuntimeError("Visualizer is not set")
return self.visualizer.edit(request, self, stream=stream)

def _repr_mimebundle_(self, include: Any = None, exclude: Any = None) -> Any:
"""Return MIME bundle for IPython notebooks."""
# See docs for the behavior of magic methods https://ipython.readthedocs.io/en/stable/config/integrating.html#custom-methods
Expand Down Expand Up @@ -79,12 +92,14 @@ def _get_plot_html(self) -> str | None:


class Visualizer(ABC):
"""Abstract interface for converting data into plots/text.

Implementations may ignore the request and choose an appropriate visualization.
"""
"""Abstract interface for converting data into plots using natural language."""

@abstractmethod
def visualize(self, request: str | None, data: ExecutionResult, *, stream: bool = True) -> VisualisationResult:
def visualize(self, request: str | None, data: ExecutionResult, *, stream: bool = False) -> VisualisationResult:
"""Produce a visualization for the given data and optional user request."""
pass

@abstractmethod
def edit(self, request: str, visualization: VisualisationResult, *, stream: bool = False) -> VisualisationResult:
"""Refine a prior visualization with a natural language request."""
pass
7 changes: 5 additions & 2 deletions databao/visualizers/dumb.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@


class DumbVisualizer(Visualizer):
def visualize(self, request: str | None, data: ExecutionResult, *, stream: bool = True) -> VisualisationResult:
def visualize(self, request: str | None, data: ExecutionResult, *, stream: bool = False) -> VisualisationResult:
plot = data.df.plot(kind="bar") if data.df is not None else None
return VisualisationResult(text="", meta={}, plot=plot, code="")
return VisualisationResult(text="", meta={}, plot=plot, code="", visualizer=self)

def edit(self, request: str, visualization: VisualisationResult, *, stream: bool = False) -> VisualisationResult:
return visualization
77 changes: 47 additions & 30 deletions databao/visualizers/vega_chat.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import dataclasses
import io
import json
import logging
Expand All @@ -9,7 +8,7 @@
from edaplot.image_utils import vl_to_png_bytes
from edaplot.llms import LLMConfig as VegaLLMConfig
from edaplot.vega import to_altair_chart
from edaplot.vega_chat.vega_chat import VegaChat, VegaChatConfig, VegaChatState
from edaplot.vega_chat.vega_chat import MessageInfo, VegaChatConfig, VegaChatGraph, VegaChatState
from langchain_core.runnables import RunnableConfig
from PIL import Image

Expand All @@ -26,6 +25,7 @@ class VegaChatResult(VisualisationResult):
spec: dict[str, Any] | None = None
spec_df: pd.DataFrame | None = None

# TODO expose as part of the VisualisationResult API
def interactive(self) -> VegaVisTool | None:
"""Return an interactive UI wizard for the Vega-Lite chart.

Expand Down Expand Up @@ -76,29 +76,11 @@ def __init__(self, llm_config: LLMConfig, *, return_interactive_chart: bool = Fa
)
self._return_interactive_chart = return_interactive_chart

def visualize(self, request: str | None, data: ExecutionResult, *, stream: bool = True) -> VegaChatResult:
if data.df is None:
return VegaChatResult(text="Nothing to visualize", meta={}, plot=None, code=None)

if request is None:
# We could also call the ChartRecommender module, but since we want a
# single output plot, we'll just use a simple prompt.
request = (
"I don't know what the data is about. Show me an interesting plot. Don't show the same plot twice."
)

vega_chat = VegaChat.from_config(config=self._vega_config, df=data.df)
start_state, compiled_graph = vega_chat.start_query(request, is_async=False)
# Use an empty `config` instead of `None` due to a bug in the "AI Agents Debugger" PyCharm plugin.
final_state: VegaChatState = GraphExecutor._invoke_graph_sync(
compiled_graph, start_state, config=RunnableConfig(), stream=stream
)
model_out = vega_chat.submit_query(final_state)

def _process_result(self, state: VegaChatState, spec_df: pd.DataFrame) -> VegaChatResult:
# Use the possibly transformed dataframe tied to the generated spec
preprocessed_df = vega_chat.dataframe
model_out = state["messages"][-1]
text = model_out.message.text()
meta = dataclasses.asdict(model_out)
meta = {"messages": state["messages"]} # Full history. Also used for edit follow ups.
spec = model_out.spec
spec_json = json.dumps(spec, indent=2) if spec is not None else None
if spec is None or not model_out.is_drawable or model_out.is_empty_chart:
Expand All @@ -108,7 +90,8 @@ def visualize(self, request: str | None, data: ExecutionResult, *, stream: bool
plot=None,
code=spec_json,
spec=spec,
spec_df=preprocessed_df,
spec_df=spec_df,
visualizer=self,
)

if not model_out.is_valid_schema and model_out.is_drawable:
Expand All @@ -117,8 +100,8 @@ def visualize(self, request: str | None, data: ExecutionResult, *, stream: bool
logger.warning("Generated Vega-Lite spec is not valid, but it is still drawable: %s", spec_json)
if self._return_interactive_chart:
# The VegaVisTool backend uses vega-embed so it can handle corrupt specs
plot = VegaVisTool(spec, preprocessed_df)
elif (png_bytes := vl_to_png_bytes(spec, preprocessed_df)) is not None:
plot = VegaVisTool(spec, spec_df)
elif (png_bytes := vl_to_png_bytes(spec, spec_df)) is not None:
# Try to convert to an Image that can still be displayed in Jupyter notebooks
plot = Image.open(io.BytesIO(png_bytes))
else:
Expand All @@ -128,18 +111,52 @@ def visualize(self, request: str | None, data: ExecutionResult, *, stream: bool
plot=None,
code=spec_json,
spec=spec,
spec_df=preprocessed_df,
spec_df=spec_df,
visualizer=self,
)
elif self._return_interactive_chart:
plot = VegaVisTool(spec, preprocessed_df)
plot = VegaVisTool(spec, spec_df)
else:
plot = to_altair_chart(spec, preprocessed_df)
plot = to_altair_chart(spec, spec_df)

return VegaChatResult(
text=text,
meta=meta,
plot=plot,
code=spec_json,
spec=spec,
spec_df=preprocessed_df,
spec_df=spec_df,
visualizer=self,
)

def _run_vega_chat(
self, request: str, df: pd.DataFrame, *, messages: list[MessageInfo] | None = None, stream: bool = False
) -> VegaChatResult:
vega_chat = VegaChatGraph(self._vega_config, df=df)
start_state = vega_chat.get_start_state(request, messages=messages)
compiled_graph = vega_chat.compile_graph(is_async=False)
# Use an empty `config` instead of `None` due to a bug in the "AI Agents Debugger" PyCharm plugin.
final_state: VegaChatState = GraphExecutor._invoke_graph_sync(
compiled_graph, start_state, config=RunnableConfig(), stream=stream
)
processed_df = vega_chat.dataframe
return self._process_result(final_state, processed_df)

def visualize(self, request: str | None, data: ExecutionResult, *, stream: bool = False) -> VegaChatResult:
if data.df is None:
return VegaChatResult(text="Nothing to visualize", meta={}, plot=None, code=None, visualizer=self)
if request is None:
# We could also call the ChartRecommender module, but since we want a
# single output plot, we'll just use a simple prompt.
request = "I don't know what the data is about. Show me an interesting plot."
return self._run_vega_chat(request, data.df, stream=stream)

def edit(self, request: str, visualization: VisualisationResult, *, stream: bool = False) -> VegaChatResult:
if not isinstance(visualization, VegaChatResult):
raise ValueError(f"{self.__class__.__name__} can only edit {VegaChatResult.__name__} objects")
if visualization.spec_df is None:
raise ValueError("No dataframe found in the provided visualization")
messages = visualization.meta.get("messages", None)
if messages is None:
raise ValueError("No message history found in the provided visualization")
return self._run_vega_chat(request, visualization.spec_df, messages=messages, stream=stream)
2 changes: 1 addition & 1 deletion tests/test_vega_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def _make_result(**kwargs: Any) -> VegaChatResult:

Allows overriding/adding fields via kwargs.
"""
base: dict[str, Any] = dict(text="", meta={}, plot=None, code=None)
base: dict[str, Any] = dict(text="", meta={}, plot=None, code=None, visualizer=None)
base.update(kwargs)
return VegaChatResult(**base)

Expand Down
8 changes: 4 additions & 4 deletions tests/test_visualisation_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@


def test_visualisation_result_get_plot_html_with_no_plot() -> None:
result = VisualisationResult(text="Test", meta={}, plot=None, code=None)
result = VisualisationResult(text="Test", meta={}, plot=None, code=None, visualizer=None)
assert result._repr_mimebundle_() is None


def test_visualisation_result_get_plot_html_with_invalid_plot() -> None:
class InvalidPlot:
pass

result = VisualisationResult(text="Test", meta={}, plot=InvalidPlot(), code=None)
result = VisualisationResult(text="Test", meta={}, plot=InvalidPlot(), code=None, visualizer=None)
assert result._repr_mimebundle_() is None


Expand All @@ -35,7 +35,7 @@ def test_visualisation_result_altair() -> None:
.interactive()
)

result = VisualisationResult(text="Test representation", meta={}, plot=chart, code=None)
result = VisualisationResult(text="Test representation", meta={}, plot=chart, code=None, visualizer=None)
assert result._repr_mimebundle_() is not None


Expand All @@ -49,5 +49,5 @@ def test_visualisation_result_matplotlib() -> None:
ax.set_title("Simple plot")
ax.set_ylabel("Y axis")
ax.set_xlabel("X axis")
result = VisualisationResult(text="Test representation", meta={}, plot=fig, code=None)
result = VisualisationResult(text="Test representation", meta={}, plot=fig, code=None, visualizer=None)
assert result._repr_mimebundle_() is not None