forked from ray-project/ray
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_parquet_schema_inference.py
More file actions
68 lines (55 loc) · 2.41 KB
/
Copy pathtest_parquet_schema_inference.py
File metadata and controls
68 lines (55 loc) · 2.41 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import pytest
import pyarrow as pa
import pyarrow.dataset as pds
import pyarrow.fs as pafs
import pyarrow.parquet as pq
from packaging.version import parse as parse_version
from ray._common.utils import get_pyarrow_version
from ray.data._internal.datasource.parquet_datasource import _infer_schema
def test_read_parquet_memory_growth(tmp_path, monkeypatch):
    """Check that ``_infer_schema`` inspects only a bounded number of fragments.

    On PyArrow >= 22, schema inference should not inspect every fragment.
    This is a regression test for a bug where ``_infer_schema`` fell back to
    reading the ``physical_schema`` of every fragment when the sampled
    fragment had a ``pa.null()`` column (PyArrow < 22.0), causing O(N)
    metadata reads and memory usage.
    """
    installed_pyarrow = get_pyarrow_version()
    if installed_pyarrow < parse_version("22.0.0"):
        pytest.skip("Bounded permissive schema inspection requires PyArrow >= 22.0.0")

    column_count = 50
    file_count = 1000
    fragments_to_inspect = 1

    # Populate the temporary directory with single-row Parquet files.
    tmp_path.mkdir(exist_ok=True)
    for file_index in range(file_count):
        columns = {f"col_{col_index}": [0] for col_index in range(column_count)}
        # Only the first file stores "null_col" as an all-null column, which
        # is what triggers the schema inference fallback; every other file
        # stores an integer value in that column.
        columns["null_col"] = pa.nulls(1) if file_index == 0 else [1]
        pq.write_table(
            pa.table(columns),
            tmp_path / f"part_{file_index:05d}.parquet",
        )

    recorded_inspect_kwargs = []
    original_factory_cls = pds.FileSystemDatasetFactory

    # Process RSS deltas for this code path are sub-MiB in CI, so rather than
    # comparing memory usage the test verifies the bounded schema-inspection
    # behavior directly by intercepting the dataset factory.
    class _RecordingFactory:
        """Wrap the real factory, logging ``inspect`` calls and forbidding ``finish``."""

        def __init__(self, *args, **kwargs):
            self._factory = original_factory_cls(*args, **kwargs)

        def inspect(self, **kwargs):
            # Remember the keyword arguments, then delegate to the real factory.
            recorded_inspect_kwargs.append(kwargs)
            return self._factory.inspect(**kwargs)

        def finish(self, *args, **kwargs):
            # Calling finish() is treated as a regression to the unbounded path.
            pytest.fail("Schema inference should not inspect every fragment")

    monkeypatch.setattr(pds, "FileSystemDatasetFactory", _RecordingFactory)

    parquet_paths = [str(entry) for entry in sorted(tmp_path.iterdir())]
    schema = _infer_schema(
        parquet_paths,
        inspect_num_fragments=fragments_to_inspect,
        filesystem=pafs.LocalFileSystem(),
    )

    # Exactly one bounded, permissive inspect() call is expected.
    expected_inspect_calls = [
        {
            "fragments": fragments_to_inspect,
            "promote_options": "permissive",
        }
    ]
    assert recorded_inspect_kwargs == expected_inspect_calls
    assert "null_col" in schema.names