Skip to content

Commit 262f7bc

Browse files
authored
fix: DataFrame plot was raising when some extra keywords were passed to encodings (e.g. x=alt.X(a, axis=alt.Axis(labelAngle=30))) (#18836)
1 parent 4d6c363 commit 262f7bc

File tree

2 files changed

+23
-8
lines changed

2 files changed

+23
-8
lines changed

py-polars/polars/dataframe/plotting.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22

33
from typing import TYPE_CHECKING, Callable, Dict, Union
44

5+
from polars.dependencies import altair as alt
6+
57
if TYPE_CHECKING:
68
import sys
79

8-
import altair as alt
910
from altair.typing import ChannelColor as Color
1011
from altair.typing import ChannelOrder as Order
1112
from altair.typing import ChannelSize as Size
@@ -25,23 +26,29 @@
2526
else:
2627
from typing_extensions import Unpack
2728

28-
Encodings: TypeAlias = Dict[
29-
str,
30-
Union[X, Y, Color, Order, Size, Tooltip],
31-
]
29+
Encoding: TypeAlias = Union[X, Y, Color, Order, Size, Tooltip]
30+
Encodings: TypeAlias = Dict[str, Encoding]
31+
32+
33+
def _maybe_extract_shorthand(encoding: Encoding) -> Encoding:
34+
if isinstance(encoding, alt.SchemaBase):
35+
# e.g. for `alt.X('x:Q', axis=alt.Axis(labelAngle=30))`, return `'x:Q'`
36+
return getattr(encoding, "shorthand", encoding)
37+
return encoding
3238

3339

3440
def _add_tooltip(encodings: Encodings, /, **kwargs: Unpack[EncodeKwds]) -> None:
3541
if "tooltip" not in kwargs:
36-
encodings["tooltip"] = [*encodings.values(), *kwargs.values()] # type: ignore[assignment]
42+
encodings["tooltip"] = [
43+
*[_maybe_extract_shorthand(x) for x in encodings.values()],
44+
*[_maybe_extract_shorthand(x) for x in kwargs.values()], # type: ignore[arg-type]
45+
] # type: ignore[assignment]
3746

3847

3948
class DataFramePlot:
4049
"""DataFrame.plot namespace."""
4150

4251
def __init__(self, df: DataFrame) -> None:
43-
import altair as alt
44-
4552
self._chart = alt.Chart(df)
4653

4754
def bar(

py-polars/tests/unit/operations/namespaces/test_plot.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import altair as alt
2+
13
import polars as pl
24

35

@@ -66,3 +68,9 @@ def test_empty_dataframe() -> None:
6668

6769
def test_nameless_series() -> None:
6870
pl.Series([1, 2, 3]).plot.kde().to_json()
71+
72+
73+
def test_x_with_axis_18830() -> None:
74+
df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]})
75+
result = df.plot.line(x=alt.X("a", axis=alt.Axis(labelAngle=-90))).to_dict()
76+
assert result["encoding"]["tooltip"] == [{"field": "a", "type": "quantitative"}]

0 commit comments

Comments
 (0)