Skip to content

Commit 1bd4269

Browse files
committed
[Data] Add map namespace support for expression operations
Signed-off-by: Hsien-Cheng Huang <ryankert01@gmail.com>
1 parent 9e2de8d commit 1bd4269

File tree

3 files changed

+281
-0
lines changed

3 files changed

+281
-0
lines changed

python/ray/data/expressions.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
if TYPE_CHECKING:
2929
from ray.data.namespace_expressions.dt_namespace import _DatetimeNamespace
3030
from ray.data.namespace_expressions.list_namespace import _ListNamespace
31+
from ray.data.namespace_expressions.map_namespace import _MapNamespace
3132
from ray.data.namespace_expressions.string_namespace import _StringNamespace
3233
from ray.data.namespace_expressions.struct_namespace import _StructNamespace
3334

@@ -634,6 +635,13 @@ def struct(self) -> "_StructNamespace":
634635

635636
return _StructNamespace(self)
636637

638+
@property
def map(self) -> "_MapNamespace":
    """Namespace exposing map/dict operations for this expression.

    Mirrors the sibling ``struct``/``dt``/``list`` accessors: the
    namespace module is imported lazily to avoid a circular import
    between ``expressions`` and ``namespace_expressions``.

    Returns:
        A ``_MapNamespace`` bound to this expression.
    """
    from ray.data.namespace_expressions import map_namespace

    return map_namespace._MapNamespace(self)
644+
637645
@property
638646
def dt(self) -> "_DatetimeNamespace":
639647
"""Access datetime operations for this expression."""
@@ -1481,6 +1489,7 @@ def download(uri_column_name: str) -> DownloadExpr:
14811489
"_ListNamespace",
14821490
"_StringNamespace",
14831491
"_StructNamespace",
1492+
"_MapNamespace",
14841493
"_DatetimeNamespace",
14851494
]
14861495

@@ -1499,6 +1508,10 @@ def __getattr__(name: str):
14991508
from ray.data.namespace_expressions.struct_namespace import _StructNamespace
15001509

15011510
return _StructNamespace
1511+
elif name == "_MapNamespace":
1512+
from ray.data.namespace_expressions.map_namespace import _MapNamespace
1513+
1514+
return _MapNamespace
15021515
elif name == "_DatetimeNamespace":
15031516
from ray.data.namespace_expressions.dt_namespace import _DatetimeNamespace
15041517

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
from __future__ import annotations
2+
3+
from dataclasses import dataclass
4+
from enum import Enum
5+
from typing import TYPE_CHECKING
6+
7+
import pyarrow
8+
import pyarrow.compute as pc
9+
10+
from ray.data.datatype import DataType
11+
from ray.data.expressions import pyarrow_udf
12+
13+
if TYPE_CHECKING:
14+
from ray.data.expressions import Expr, UDFExpr
15+
16+
17+
class MapComponent(str, Enum):
    """Which child of a map-like Arrow array to project out."""

    # String values double as human-readable names in error messages.
    KEYS = "keys"
    VALUES = "values"
21+
22+
def _extract_map_component(
    arr: pyarrow.Array, component: MapComponent
) -> pyarrow.Array:
    """Extract the keys or values child of a map-like Arrow array.

    Handles both genuine ``MapArray`` columns and the "physical map"
    encoding ``ListArray<Struct<key, value>>``. This serves as the
    primary implementation since PyArrow does not yet expose dedicated
    compute kernels for map projection in the Python API.

    Args:
        arr: A ``MapArray``, a ``ListArray`` of structs with at least two
            fields, or a ``ChunkedArray`` of either.
        component: Whether to project the keys or the values child.

    Returns:
        A ``ListArray`` (or ``ChunkedArray`` of them) containing the
        requested component, preserving per-row list boundaries and nulls.

    Raises:
        TypeError: If ``arr`` is not a supported map-like array.
    """
    # 1. Handle Chunked Arrays (recursion): project each chunk and re-chunk.
    if isinstance(arr, pyarrow.ChunkedArray):
        return pyarrow.chunked_array(
            [_extract_map_component(chunk, component) for chunk in arr.chunks]
        )

    child_array = None

    # Case 1: a true Arrow MapArray exposes keys/items children directly.
    if isinstance(arr, pyarrow.MapArray):
        if component == MapComponent.KEYS:
            child_array = arr.keys
        else:
            child_array = arr.items

    # Case 2: ListArray<Struct<Key, Value>> ("physical" map encoding).
    elif isinstance(arr, pyarrow.ListArray):
        flat_values = arr.values
        if (
            isinstance(flat_values, pyarrow.StructArray)
            and flat_values.type.num_fields >= 2
        ):
            idx = 0 if component == MapComponent.KEYS else 1
            child_array = flat_values.field(idx)

    if child_array is None:
        # Fail loudly with the offending type instead of letting the
        # slice()/from_arrays() calls below raise an opaque error on None.
        raise TypeError(
            f"Cannot extract map {component.value} from array of type "
            f"{arr.type}; expected a map or list<struct> array."
        )

    # Reconstruct ListArray & normalize offsets: a sliced parent keeps the
    # full child buffer, so re-anchor both the child slice and the offsets
    # at zero before rebuilding.
    offsets = arr.offsets
    if len(offsets) > 0:  # Handle offsets changes
        start_offset = offsets[0]
        if start_offset.as_py() != 0:
            # Slice child_array to match normalized offsets
            end_offset = offsets[-1]
            child_array = child_array.slice(
                offset=start_offset.as_py(),
                length=(end_offset - start_offset).as_py(),
            )
            offsets = pc.subtract(offsets, start_offset)

    return pyarrow.ListArray.from_arrays(
        offsets=offsets, values=child_array, mask=arr.is_null()
    )
71+
72+
73+
@dataclass
class _MapNamespace:
    """Namespace for map operations on expression columns.

    This namespace provides methods for operating on map-typed columns
    (including MapArrays and ListArrays of Structs) using PyArrow UDFs.

    Example:
        >>> from ray.data.expressions import col
        >>> # Get keys from map column
        >>> expr = col("headers").map.keys()
        >>> # Get values from map column
        >>> expr = col("headers").map.values()
    """

    # The expression this namespace projects from (bound by ``Expr.map``).
    _expr: "Expr"

    def keys(self) -> "UDFExpr":
        """Returns a list expression containing the keys of the map.

        Example:
            >>> from ray.data.expressions import col
            >>> # Get keys from map column
            >>> expr = col("headers").map.keys()

        Returns:
            A list expression containing the keys.
        """
        return self._create_projection_udf(MapComponent.KEYS)

    def values(self) -> "UDFExpr":
        """Returns a list expression containing the values of the map.

        Example:
            >>> from ray.data.expressions import col
            >>> # Get values from map column
            >>> expr = col("headers").map.values()

        Returns:
            A list expression containing the values.
        """
        return self._create_projection_udf(MapComponent.VALUES)

    def _create_projection_udf(self, component: MapComponent) -> "UDFExpr":
        """Helper to generate UDFs for map projections.

        Infers a precise ``list<inner>`` return dtype when the input
        expression's Arrow type is known; otherwise falls back to
        ``DataType(object)``.
        """
        return_dtype = DataType(object)
        if self._expr.data_type.is_arrow_type():
            arrow_type = self._expr.data_type.to_arrow_dtype()

            # ListArray<Struct<key, value>>: a map stored "physically".
            is_physical_map = (
                pyarrow.types.is_list(arrow_type)
                and pyarrow.types.is_struct(arrow_type.value_type)
                and arrow_type.value_type.num_fields >= 2
            )

            inner_arrow_type = None
            if pyarrow.types.is_map(arrow_type):
                inner_arrow_type = (
                    arrow_type.key_type
                    if component == MapComponent.KEYS
                    else arrow_type.item_type
                )
            elif is_physical_map:
                idx = 0 if component == MapComponent.KEYS else 1
                inner_arrow_type = arrow_type.value_type.field(idx).type

            # Explicit None check: do not rely on the incidental truthiness
            # of a pyarrow DataType object to signal presence.
            if inner_arrow_type is not None:
                return_dtype = DataType.list(DataType.from_arrow(inner_arrow_type))

        @pyarrow_udf(return_dtype=return_dtype)
        def _project_map(arr: pyarrow.Array) -> pyarrow.Array:
            return _extract_map_component(arr, component)

        return _project_map(self._expr)

python/ray/data/tests/test_namespace_expressions.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,21 @@ def _create_dataset(
4646

4747
# Pytest parameterization for all dataset creation formats
4848
DATASET_FORMATS = ["pandas", "arrow"]
49+
# Map-typed columns are only exercised against Arrow-backed datasets here;
# presumably pandas round-trips do not preserve MapArray — TODO confirm.
MAP_DATASET_FORMATS = ["arrow"]
50+
51+
52+
def _create_map_dataset(dataset_format: str):
    """Create a dataset backed by an Arrow MapArray column."""
    rows = [
        {"attrs": {"color": "red", "size": "M"}},
        {"attrs": {"brand": "Ray"}},
    ]
    # Build the backing table with an explicit map<string, string> column.
    attrs_array = pa.array(
        [row["attrs"] for row in rows], type=pa.map_(pa.string(), pa.string())
    )
    arrow_table = pa.table({"attrs": attrs_array})
    return _create_dataset(rows, dataset_format, arrow_table)
4964

5065

5166
# ──────────────────────────────────────
@@ -530,6 +545,112 @@ def test_struct_nested_bracket(self, dataset_format):
530545
assert rows_same(result, expected)
531546

532547

548+
# ──────────────────────────────────────
549+
# Map Namespace Tests
550+
# ──────────────────────────────────────
551+
552+
553+
@pytest.mark.parametrize("dataset_format", MAP_DATASET_FORMATS)
class TestMapNamespace:
    """Tests for map namespace operations."""

    def test_map_keys(self, dataset_format):
        # keys() should yield one list of keys per row, in map entry order.
        ds = _create_map_dataset(dataset_format)

        result = ds.with_column("keys", col("attrs").map.keys()).to_pandas()
        result = result.drop(columns=["attrs"])

        expected = pd.DataFrame({"keys": [["color", "size"], ["brand"]]})
        assert rows_same(result, expected)

    def test_map_values(self, dataset_format):
        # values() should yield one list of values per row, aligned with keys().
        ds = _create_map_dataset(dataset_format)

        result = ds.with_column("values", col("attrs").map.values()).to_pandas()
        result = result.drop(columns=["attrs"])

        expected = pd.DataFrame({"values": [["red", "M"], ["Ray"]]})
        assert rows_same(result, expected)

    def test_physical_map_extraction(self, dataset_format):
        """Test extraction works on List<Struct> (Physical Maps)."""
        # Construct List<Struct<k, v>>
        struct_type = pa.struct([pa.field("k", pa.string()), pa.field("v", pa.int64())])
        list_type = pa.list_(struct_type)

        data_py = [[{"k": "a", "v": 1}], [{"k": "b", "v": 2}]]
        arrow_table = pa.Table.from_arrays(
            [pa.array(data_py, type=list_type)], names=["data"]
        )

        items_data = [{"data": row} for row in data_py]
        ds = _create_dataset(items_data, dataset_format, arrow_table)

        result = (
            ds.with_column("keys", col("data").map.keys())
            .with_column("values", col("data").map.values())
            .to_pandas()
        )

        expected = pd.DataFrame(
            {
                "data": data_py,
                "keys": [["a"], ["b"]],
                "values": [[1], [2]],
            }
        )
        assert rows_same(result, expected)

    def test_map_sliced_offsets(self, dataset_format):
        """Test extraction works correctly on sliced Arrow arrays (offset > 0)."""
        items = [{"m": {"id": i}} for i in range(10)]
        map_type = pa.map_(pa.string(), pa.int64())
        arrays = pa.array([row["m"] for row in items], type=map_type)
        table = pa.Table.from_arrays([arrays], names=["m"])

        # Force offsets by slicing the table before ingestion
        sliced_table = table.slice(offset=7, length=3)
        ds = ray.data.from_arrow(sliced_table)

        result = ds.with_column("vals", col("m").map.values()).to_pandas()
        result = result.drop(columns=["m"])

        expected = pd.DataFrame({"vals": [[7], [8], [9]]})
        assert rows_same(result, expected)

    def test_map_nulls_and_empty(self, dataset_format):
        """Test handling of null maps and empty maps."""
        items_data = [{"m": {"a": 1}}, {"m": {}}, {"m": None}]

        map_type = pa.map_(pa.string(), pa.int64())
        arrays = pa.array([row["m"] for row in items_data], type=map_type)
        arrow_table = pa.Table.from_arrays([arrays], names=["m"])
        ds = _create_dataset(items_data, dataset_format, arrow_table)

        # Use take_all() to avoid pandas casting errors with mixed None/list types
        rows = (
            ds.with_column("keys", col("m").map.keys())
            .with_column("values", col("m").map.values())
            .take_all()
        )

        # Populated map projects its entries; empty map projects empty lists;
        # a null map stays null (not an empty list).
        assert list(rows[0]["keys"]) == ["a"] and list(rows[0]["values"]) == [1]
        assert len(rows[1]["keys"]) == 0 and len(rows[1]["values"]) == 0
        assert rows[2]["keys"] is None and rows[2]["values"] is None

    def test_map_chaining(self, dataset_format):
        ds = _create_map_dataset(dataset_format)

        # map.keys() returns a list, so .list.len() should apply
        result = ds.with_column(
            "num_keys", col("attrs").map.keys().list.len()
        ).to_pandas()
        result = result.drop(columns=["attrs"])

        expected = pd.DataFrame({"num_keys": [2, 1]})
        assert rows_same(result, expected)
653+
533654
# ──────────────────────────────────────
534655
# Datetime Namespace Tests
535656
# ──────────────────────────────────────

0 commit comments

Comments
 (0)