|
11 | 11 | from dask._task_spec import Task |
12 | 12 | from dask.typing import Key |
13 | 13 | from dask.utils import funcname, parse_bytes |
| 14 | +from fsspec.utils import read_block |
14 | 15 |
|
15 | 16 | from dask_expr._expr import Index, Projection, determine_column_projection |
16 | 17 | from dask_expr._util import _convert_to_list, _tokenize_deterministic |
@@ -208,7 +209,7 @@ def read_fragments( |
208 | 209 | promote_options="permissive", |
209 | 210 | ) |
210 | 211 | if len(tables) > 1 |
211 | | - else tables, |
| 212 | + else tables[0], |
212 | 213 | **table_to_dataframe_options, |
213 | 214 | ) |
214 | 215 |
|
@@ -329,10 +330,107 @@ def _simplify_up(self, parent, dependents): |
329 | 330 |
|
def _simplify_down(self):
    """Lower this generic expression to a format-specific one.

    Dispatches on the dataset's default file extension; formats without a
    specialized subclass fall through (returning ``None``, i.e. no rewrite).
    """
    format_dispatch = {
        "csv": FromArrowDatasetCSV,
        "json": FromArrowDatasetJSON,
        "parquet": FromArrowDatasetParquet,
    }
    specialized = format_dispatch.get(self.dataset.format.default_extname)
    if specialized is not None:
        return specialized(*self.operands)
334 | 339 |
|
335 | 340 |
|
class FromArrowDatasetCSV(FromArrowDataset):
    """CSV variant of ``FromArrowDataset`` supporting intra-file splitting."""

    @classmethod
    def _partial_fragment_to_table(
        cls,
        fragment,
        schema,
        filters,
        split_index,
        split_count,
        options,
    ):
        """Read one byte-range split of a CSV fragment into an Arrow table.

        The file is divided into ``split_count`` roughly equal byte ranges;
        ``split_index`` selects which range this task reads. Ranges are
        snapped to newline boundaries by ``_read_byte_block``.
        """
        path = fragment.path
        fs = fragment.filesystem
        file_size = fs.get_file_info(path).size

        # Byte range for this split; the last split absorbs the remainder.
        length = file_size // split_count
        start = length * split_index
        if split_index == split_count - 1:
            length = file_size - start

        delim = b"\n"
        scan_opts = fragment.format.default_fragment_scan_options
        names = scan_opts.column_names
        skipped = scan_opts.skip_rows

        prefix = b""
        if split_index:
            # Non-first splits must re-supply the header row, unless column
            # names are given explicitly or the scan already skips rows.
            if not names and not skipped:
                prefix = _read_byte_block(path, fs, 0, 1, delimiter=delim)
            # Pad with blank rows so skip_rows consumes padding, not data.
            prefix += delim * skipped

        body = _read_byte_block(path, fs, start, length, delimiter=delim)
        partial = fragment.format.make_fragment(pa.py_buffer(prefix + body))
        return partial.to_table(filter=filters, **options)
| 396 | + |
class FromArrowDatasetJSON(FromArrowDataset):
    """JSON(-lines) variant of ``FromArrowDataset`` supporting intra-file splitting."""

    @classmethod
    def _partial_fragment_to_table(
        cls,
        fragment,
        schema,
        filters,
        split_index,
        split_count,
        options,
    ):
        """Read one newline-delimited byte-range split of a JSON fragment.

        Unlike CSV, no header handling is needed: each line is a complete
        record, so the split is just a delimiter-snapped byte range.
        """
        path = fragment.path
        fs = fragment.filesystem
        file_size = fs.get_file_info(path).size

        # Byte range for this split; the last split absorbs the remainder.
        length = file_size // split_count
        start = length * split_index
        if split_index == split_count - 1:
            length = file_size - start

        raw = _read_byte_block(path, fs, start, length, delimiter=b"\n")
        partial = fragment.format.make_fragment(pa.py_buffer(raw))
        return partial.to_table(filter=filters, **options)
| 432 | + |
| 433 | + |
336 | 434 | class FromArrowDatasetParquet(FromArrowDataset): |
337 | 435 | _scan_options = pa.dataset.ParquetFragmentScanOptions( |
338 | 436 | pre_buffer=True, |
@@ -384,3 +482,21 @@ def _partial_fragment_to_table( |
384 | 482 | filter=filters, |
385 | 483 | **options, |
386 | 484 | ) |
| 485 | + |
| 486 | + |
def _read_byte_block(path, filesystem, offset, nbytes, delimiter=None):
    """Read a delimited byte range from ``path`` on ``filesystem``.

    Delegates to fsspec's ``read_block``, which snaps ``offset``/``nbytes``
    to ``delimiter`` boundaries when a delimiter is given, so callers never
    receive a partial record at either end of the range.
    """
    with filesystem.open_input_file(path) as stream:
        return read_block(stream, offset, nbytes, delimiter)
0 commit comments