diff --git a/deltacat/examples/experimental/rivulet/pytorch_demo.ipynb b/deltacat/examples/experimental/rivulet/pytorch_demo.ipynb index f989c7e05..06cc75e52 100644 --- a/deltacat/examples/experimental/rivulet/pytorch_demo.ipynb +++ b/deltacat/examples/experimental/rivulet/pytorch_demo.ipynb @@ -1,8 +1,9 @@ { "cells": [ { - "metadata": {}, "cell_type": "markdown", + "id": "2fb18b4d46a9548", + "metadata": {}, "source": [ "# PyTorch Demo: Sentiment Analysis and Question Detection with Rivulet Dataset\n", "\n", @@ -13,14 +14,16 @@ "- **Pytorch Integration:** Easily allows passing of data between pytorch models and transformers.\n", "- **Non-Destructive Transformation:** Transforms the data (e.g., adding sentiment and question classification) without modifying the original dataset.\n", "- **Exporting Data:** Exports the modified dataset to supported formats such as Parquet and JSON for further analysis." - ], - "id": "2fb18b4d46a9548" + ] }, { + "cell_type": "code", + "execution_count": null, + "id": "initial_id", "metadata": { "collapsed": true }, - "cell_type": "code", + "outputs": [], "source": [ "import torch\n", "from typing import List\n", @@ -29,14 +32,14 @@ "import pathlib\n", "import pyarrow as pa\n", "import pyarrow.csv as csv" - ], - "id": "initial_id", - "outputs": [], - "execution_count": null + ] }, { - "metadata": {}, "cell_type": "code", + "execution_count": null, + "id": "51a2ddaed83da5f3", + "metadata": {}, + "outputs": [], "source": [ "# Load tokenizer and model for sentiment analysis\n", "sentiment_tokenizer = AutoTokenizer.from_pretrained(\"distilbert-base-uncased-finetuned-sst-2-english\")\n", @@ -46,14 +49,14 @@ "question_tokenizer = AutoTokenizer.from_pretrained(\"shahrukhx01/question-vs-statement-classifier\")\n", "question_model = AutoModelForSequenceClassification.from_pretrained(\"shahrukhx01/question-vs-statement-classifier\")\n", "question_model.eval()" - ], - "id": "51a2ddaed83da5f3", - "outputs": [], - "execution_count": null + ] }, { - 
"metadata": {}, "cell_type": "code", + "execution_count": null, + "id": "b74792a57b9b28c1", + "metadata": {}, + "outputs": [], "source": [ "# Create a rivulet dataset using the CSV file\n", "cwd = pathlib.Path.cwd()\n", @@ -65,14 +68,14 @@ " merge_keys=\"msg_id\"\n", ")\n", "ds.print(num_records=10)" - ], - "id": "b74792a57b9b28c1", - "outputs": [], - "execution_count": null + ] }, { - "metadata": {}, "cell_type": "code", + "execution_count": null, + "id": "1b90411fd69378e9", + "metadata": {}, + "outputs": [], "source": [ "# define a new schema with fields for pytorch classification\n", "ds.add_fields([\n", @@ -80,14 +83,14 @@ " (\"sentiment\", dc.Datatype.float()),\n", " (\"is_question\", dc.Datatype.float())\n", "], schema_name=\"message_classifier\", merge_keys=[\"msg_id\"])" - ], - "id": "1b90411fd69378e9", - "outputs": [], - "execution_count": null + ] }, { - "metadata": {}, "cell_type": "code", + "execution_count": null, + "id": "587f17e09e5d306a", + "metadata": {}, + "outputs": [], "source": [ "# compute classification values and update records in dataset\n", "def compute_sentiments(batch: pa.RecordBatch) -> List[float]:\n", @@ -134,21 +137,18 @@ "\n", "dataset_writer.flush()\n", "print(\"Sentiment and is_question values have been computed and updated in the dataset.\")" - ], - "id": "587f17e09e5d306a", - "outputs": [], - "execution_count": null + ] }, { - "metadata": {}, "cell_type": "code", + "execution_count": null, + "id": "8ef2dd2a1bc4e66a", + "metadata": {}, + "outputs": [], "source": [ "# export to a supported format (JSON, PARQUET, FEATHER)\n", "ds.export(file_uri=\"./output.json\", format=\"json\")" - ], - "id": "8ef2dd2a1bc4e66a", - "outputs": [], - "execution_count": null + ] } ], "metadata": { diff --git a/deltacat/examples/experimental/rivulet/wds_demo.py b/deltacat/examples/experimental/rivulet/wds_demo.py new file mode 100644 index 000000000..682dc08e3 --- /dev/null +++ b/deltacat/examples/experimental/rivulet/wds_demo.py @@ -0,0 +1,101 @@ 
+import torch +from deltacat.storage.rivulet import Dataset +import pyarrow as pa +from typing import List +from PIL import Image +import io +from deltacat.storage.rivulet.schema.schema import Datatype +from transformers import AutoImageProcessor, AutoModelForImageClassification + + +# tar_path = "deltacat/tests/test_utils/resources/imagenet1k-train-0000.tar" +tar_path = "deltacat/tests/test_utils/resources/nestedjson.tar" + +# Load the dataset from the tar file +ds = Dataset.from_webdataset( + name="bird_species_test", # Name of the dataset + file_uri=tar_path, # Location of the tar file + merge_keys="filename", # Merge batches using the 'filename' key +) + +# Print the available fields in the dataset +print(ds.fields) + +# Load the image processor and classification model from HuggingFace +processor = AutoImageProcessor.from_pretrained("chriamue/bird-species-classifier") +model = AutoModelForImageClassification.from_pretrained( + "chriamue/bird-species-classifier" +) +model.eval() + + +# Function to classify bird species from a record batch +def compute_bird_species(batch: pa.RecordBatch) -> List[str]: + # Extract the binary image column + image_column = batch.column("image_binary").to_pylist() + + # Initialize list to store PIL Image objects + pil_images = [] + for img_binary in image_column: + try: + # Convert binary data to image and convert to RGB + img = Image.open(io.BytesIO(img_binary)).convert("RGB") + pil_images.append(img) + except Exception as e: + # Print error if image decoding fails + print(f"Error reading image: {e}") + + # If images were successfully decoded + if pil_images: + # Preprocess images and run them through the model + inputs = processor(images=pil_images, return_tensors="pt") + with torch.no_grad(): # Disable gradient computation + outputs = model(**inputs) + + # Get the predicted label indices + predicted_ids = torch.argmax(outputs.logits, dim=1).tolist() + + # Map indices to human-readable class labels + predicted_labels = 
[model.config.id2label[idx] for idx in predicted_ids] + + return predicted_labels + else: + # Return empty list if no images were valid + return [] + + +# Add new fields to the dataset: filename and predicted bird species +ds.add_fields( + [ + ("filename", Datatype.string()), # String type for filename + ("bird_species", Datatype.string()), # String type for predicted label + ], + schema_name="bird_species_classifier", + merge_keys=["filename"], +) # Schema name and merge key + +# Initialize writer to store output under the new schema +dataset_writer = ds.writer(schema_name="bird_species_classifier") + +# Iterate over each Arrow batch in the dataset +for batch in ds.scan().to_arrow(): + print(batch) # Print the batch contents + filenames = batch.column("filename").to_pylist() # Extract filenames + bird_labels = compute_bird_species(batch) # Run classification on batch + + rows_to_write = [] # Prepare rows to be written + if bird_labels: + # Create a list of dictionaries combining filename and predicted species + rows_to_write = [ + {"filename": fname, "bird_species": bird_species} + for fname, bird_species in zip(filenames, bird_labels) + ] + print("ROWS", rows_to_write) # Print rows to be written + dataset_writer.write(rows_to_write) # Write the output to dataset + +dataset_writer.flush() + +# Export the results to a local JSON file +ds.export(file_uri="./bird_classification_species_predictions.json", format="json") + +print("Bird species classification complete.") diff --git a/deltacat/experimental/storage/rivulet/dataset.py b/deltacat/experimental/storage/rivulet/dataset.py index 453d7a239..c16fc7ebb 100644 --- a/deltacat/experimental/storage/rivulet/dataset.py +++ b/deltacat/experimental/storage/rivulet/dataset.py @@ -38,7 +38,9 @@ from deltacat.experimental.storage.rivulet.reader.query_expression import ( QueryExpression, ) - +from deltacat.experimental.storage.rivulet.reader.webdataset_reader import ( + WebDatasetReader, +) from 
deltacat.experimental.storage.rivulet.writer.dataset_writer import DatasetWriter from deltacat.experimental.storage.rivulet.writer.memtable_dataset_writer import ( MemtableDatasetWriter, @@ -479,6 +481,77 @@ def from_json( return dataset + @classmethod + def from_webdataset( + cls, + name: str, + file_uri: str, + merge_keys: str | Iterable[str] = None, + metadata_uri: Optional[str] = None, + schema_mode: str = "union", + batch_size: Optional[int] = 1, + filesystem: Optional[pyarrow.fs.FileSystem] = None, + namespace: str = DEFAULT_NAMESPACE, + ) -> "Dataset": + """ + Create a Dataset from a single webdataset tar file. + + TODO: Add support for reading directories with multiple WDS files. + + Args: + name: Unique identifier for the dataset. + metadata_uri: Base URI for the dataset, where dataset metadata is stored. If not specified, will be placed in ${file_uri}/riv-meta + file_uri: Path to a single webdataset file. + merge_keys: Fields to specify as merge keys for future 'zipper merge' operations on the dataset. + schema_mode: Currently ignored as this is for a single file. + + Returns: + Dataset: New dataset instance with the schema automatically inferred + from the tar file. + """ + # TODO: integrate this with filesystem from deltacat catalog + file_uri, file_fs = FileStore.filesystem(file_uri, filesystem=filesystem) + if metadata_uri is None: + metadata_uri = posixpath.join(posixpath.dirname(file_uri), "riv-meta") + metadata_fs = file_fs + else: + metadata_uri, metadata_fs = FileStore.filesystem( + metadata_uri, filesystem=filesystem + ) + # TODO: when integrating deltacat consider if we can support multiple filesystems + if file_fs.type_name != metadata_fs.type_name: + raise ValueError( + "File URI and metadata URI must be on the same filesystem." 
+ ) + + # Read the WebDataset into a PyArrow Table + wds_parser = WebDatasetReader( + name=name, + file_uri=file_uri, + merge_keys=merge_keys, + schema_mode=schema_mode, + batch_size=batch_size, + namespace=namespace, + ) + pyarrow_table = wds_parser.to_pyarrow() + # Create the Dataset and write to it + dataset_schema = Schema.from_pyarrow( + pyarrow_table.schema, merge_keys=merge_keys + ) + + dataset = cls( + dataset_name=name, + metadata_uri=metadata_uri, + schema=dataset_schema, + filesystem=file_fs, + namespace=namespace, + ) + + writer = dataset.writer() + writer.write(pyarrow_table.to_batches()) + writer.flush() + return dataset + @classmethod def from_csv( cls, @@ -522,7 +595,7 @@ def from_csv( ) # Read the CSV file into a PyArrow Table - table = pyarrow.csv.read_csv(file_uri, filesystem=file_fs) + table = pyarrow.csv.read_csv(file_uri) pyarrow_schema = table.schema # Create the dataset schema @@ -718,7 +791,6 @@ def writer( :return: new dataset writer with a schema at the conjunction of the given schemas """ schema_name = schema_name or ALL - return MemtableDatasetWriter( self._file_provider, self.schemas[schema_name], self._locator, file_format ) diff --git a/deltacat/experimental/storage/rivulet/reader/webdataset_reader.py b/deltacat/experimental/storage/rivulet/reader/webdataset_reader.py new file mode 100644 index 000000000..9de7e9f05 --- /dev/null +++ b/deltacat/experimental/storage/rivulet/reader/webdataset_reader.py @@ -0,0 +1,119 @@ +from __future__ import annotations + +import logging + +import os +import tarfile +from typing import Optional, Iterable + +import pyarrow as pa +import pyarrow.json +from deltacat.constants import ( + DEFAULT_NAMESPACE, +) + +from deltacat import logs + +logger = logs.configure_deltacat_logger(logging.getLogger(__name__)) + + +class WebDatasetReader: + def __init__( + self, + name: str, + file_uri: str, + merge_keys: str | Iterable[str] = None, + schema_mode: str = "union", + batch_size: Optional[int] = 1, + namespace: 
str = DEFAULT_NAMESPACE, + ): + merge_key = self._validate_single_merge_key(merge_keys) + + self.name = name + self.file_uri = file_uri + self.merge_key = merge_key + self.schema_mode = schema_mode + self.batch_size = batch_size + self.namespace = namespace + + def _validate_single_merge_key(self, merge_keys): + """Checks that there is only one merge key, and returns this merge key (Iterable or String).""" + if not merge_keys or not isinstance(merge_keys, str): + if merge_keys is not None and len(merge_keys) == 1: + return merge_keys[0] + else: + raise ValueError( + "Multiple merge keys are not supported in from_webdataset(). Please specify only 1 merge key as a string." + ) + return merge_keys + + def _align_to_schema(self, tbl: pa.Table, schema: pa.Schema) -> pa.Table: + # add missing fields + for f in schema: + if f.name not in tbl.column_names: + tbl = tbl.append_column(f.name, pa.nulls(len(tbl), type=f.type)) + # reorder to match schema field order + rebuild table in the correct order + tbl = pa.table([tbl[f.name] for f in schema], schema=schema) + return tbl + + def to_pyarrow(self): + """Returns a pyarrow table of the tar members, storing file content for each member in the webdataset in a new media_binary column.""" + tables = [] + media_binaries = [] + with tarfile.open(self.file_uri, "r") as tar: + tar_members = tar.getmembers() + tar_members = [ + member + for member in tar_members + if member.isfile() and not member.name.startswith("._") + ] + # Get individual pyarrow tables (1 per json member) and media_binaries + for i in range(0, len(tar_members), 2): + # Get the json member and corresponding media member by index + # With webdatasets: guaranteed that the json and its corresponding media file are next to each other + # However the order of media file first or json first is not specified + member, media_member = tar_members[i], tar_members[i + 1] + if not member.name.endswith(".json"): + member, media_member = media_member, member + try: + f = tar.extractfile(member) + if not f: + 
continue + tbl = pyarrow.json.read_json(f) + # Validate that the member actually corresponds with the media_member + media_filename = os.path.basename(tbl[self.merge_key][0].as_py()) + key, _ = member.name.split(".", 1) + expected_media_name = f"{key}.{media_filename}" + if expected_media_name != media_member.name: + logger.warning( + "Mismatched filename for sample %s: expected media %s but found %s", + key, + expected_media_name, + media_member.name, + ) + continue + f_media = tar.extractfile(media_member) + if not f_media: + continue + media_binary = f_media.read() + # Add media binary and table to respective lists + media_binaries.extend([media_binary]) + tables.append(tbl) + except Exception as e: + logger.warning("Error processing member %s: %s", member.name, e) + if len(tables) != len(media_binaries): + logger.error( + "Mismatch between number of JSON tables (%d) and media binaries (%d)", + len(tables), + len(media_binaries), + ) + raise ValueError("Mismatched JSON tables and media binaries in webdataset tar") + + unified = pa.unify_schemas( + [t.schema for t in tables], promote_options="permissive" + ) + tables = [self._align_to_schema(t, unified) for t in tables] + final_table = pa.concat_tables(tables) + media_binary_array = pa.array(media_binaries, type=pa.binary()) + final_table = final_table.append_column("media_binary", media_binary_array) + return final_table diff --git a/deltacat/experimental/storage/rivulet/schema/schema.py b/deltacat/experimental/storage/rivulet/schema/schema.py index 4214c12a2..8ffec3e38 100644 --- a/deltacat/experimental/storage/rivulet/schema/schema.py +++ b/deltacat/experimental/storage/rivulet/schema/schema.py @@ -117,6 +117,7 @@ def from_pyarrow( Raises: ValueError: If key is not found in schema """ + merge_keys = [] if merge_keys is None else merge_keys merge_keys = [merge_keys] if isinstance(merge_keys, str) else merge_keys fields = {} diff --git a/deltacat/tests/experimental/storage/rivulet/conftest.py 
b/deltacat/tests/experimental/storage/rivulet/conftest.py index 8d7cdedb8..4bf3f12a9 100644 --- a/deltacat/tests/experimental/storage/rivulet/conftest.py +++ b/deltacat/tests/experimental/storage/rivulet/conftest.py @@ -10,6 +10,8 @@ import string from PIL import Image +# import webdataset as wds + FIXTURE_ROW_COUNT = 10000 diff --git a/deltacat/tests/experimental/storage/rivulet/schema/test_wds.py b/deltacat/tests/experimental/storage/rivulet/schema/test_wds.py new file mode 100644 index 000000000..0ed1ef9b8 --- /dev/null +++ b/deltacat/tests/experimental/storage/rivulet/schema/test_wds.py @@ -0,0 +1,333 @@ +import os +import pytest +import shutil +import json + +import tempfile +from pathlib import Path +import webdataset as wds + +from deltacat.experimental.storage.rivulet.dataset import Dataset +from deltacat.experimental.storage.rivulet import Field, Datatype + + +@pytest.fixture(scope="class") +def temp_dir(tmp_path_factory) -> Path: + # One directory for the whole class + return tmp_path_factory.mktemp("rivulet_suite") + + +def _add_txt_files_and_wds_tar(files, base_dir, name): + for _, content in files.items(): + rel_path = content["filename"] + full_path = os.path.join(base_dir, rel_path) + + # Ensure the directory exists + os.makedirs(os.path.dirname(full_path), exist_ok=True) + + with open(full_path, "w") as f: + f.write("Test .txt content.") + # write shard with WebDataset + shard_path_pattern = str(base_dir / f"{name}-%06d.tar") + with wds.ShardWriter(shard_path_pattern, maxcount=1000) as sink: + for i, (json_name, content) in enumerate(files.items()): + txt_path = os.path.join(base_dir, content["filename"]) + with open(txt_path, "rb") as f: + txt_bytes = f.read() + + sample = { + "__key__": f"{i:06d}", # still needed for WDS, but won't prefix files + os.path.basename(content["filename"]): txt_bytes, + json_name: json.dumps(content).encode("utf-8"), + } + sink.write(sample) + + return [str(p) for p in Path(base_dir).glob(f"{name}-*.tar")] + + 
+@pytest.fixture +def sample_wds_simple(temp_dir: Path): + """Create a simple WebDataset shard using webdataset.ShardWriter.""" + name = "simple" + files = { + f"{name}_first.json": { + "label": 1, + "width": 500, + "height": 429, + "filename": "n01443537/n01443537_14753.TXT", + "extra": 101, + }, + f"{name}_second.json": { + "label": 2, + "width": 200, + "height": 300, + "filename": "n01443538/n01443538_14754.TXT", + "extra": 102, + }, + } + + # create corresponding dummy .txt files on disk + return _add_txt_files_and_wds_tar(files, temp_dir, name) + + +@pytest.fixture +def sample_wds_simple_2(temp_dir): + """Create a simple WebDataset shard using webdataset.ShardWriter.""" + name = "simple_2" + files = { + f"{name}_first.json": { + "label": 1, + "width": 500, + "height": 429, + "filename": "n01443537/n01443537_14753.TXT", + "extra": 101, + }, + f"{name}_second.json": { + "label": 2, + "width": 200, + "height": 300, + "filename": "n01443538/n01443538_14754.TXT", + "extra": 102, + }, + } + # create corresponding dummy .txt files on disk + return _add_txt_files_and_wds_tar(files, temp_dir, name) + + +@pytest.fixture +def sample_wds_long(temp_dir): + name = "long" + files = {} + for i in range(6): + files[f"long_{i}.json"] = { + "label": i, + "width": 100 + i * 50, + "height": 200 + i * 50, + "filename": f"n0144353{i}/n0144353{i}_1475{i}.TXT", + "extra": 100 + i, + } + return _add_txt_files_and_wds_tar(files, temp_dir, name) + + +@pytest.fixture +def sample_wds_inconsistent(temp_dir): + name = "inconsistent" + files = { + f"{name}_first.json": { + "label": 1, + "width": 500, + "height": 429, + "filename": "n01443537/n01443537_14753.TXT", + "extra": 101, + }, + f"{name}_second.json": { + "label": 2, + "width": 200, + "height": 300, + "filename": "n01443538/n01443538_14754.TXT", + }, + f"{name}_third.json": { + "label": 3, + "width": 200, + "height": 300, + "filename": "n01443539/n01443538_14755.TXT", + "extra": 103, + "extra_3": 333, + }, + f"{name}_fourth.json": { + 
"label": 4, + "width": 200, + "height": 300, + "filename": "n01443540/n01443538_14756.TXT", + "extra_4": 444, + }, + f"{name}_fifth.json": { + "diff_label": 5, + "diff_width": 500, + "diff_height": 600, + "filename": "n01443541/n01443538_14757.TXT", + }, + } + return _add_txt_files_and_wds_tar(files, temp_dir, name) + + +@pytest.fixture +def sample_wds_diff_data_types(temp_dir: Path): + """Create a simple WebDataset shard with different data types under the same column name.""" + name = "test_conflicting_data_types" + files = { + f"{name}_first.json": { + "label": 1, + "width": 500, + "height": 429, + "filename": "n01443537/n01443537_14753.TXT", + }, + f"{name}_second.json": { + "label": 2, + "width": 200, + "height": 300.5, + "filename": "n01443538/n01443538_14754.TXT", + }, + } + + # create corresponding dummy .txt files on disk + return _add_txt_files_and_wds_tar(files, temp_dir, name) + + +class TestFromWebDataset: + @classmethod + def setup_class(cls): + cls.temp_dir = tempfile.mkdtemp() + + @classmethod + def teardown_class(cls): + shutil.rmtree(cls.temp_dir) + + def test_consistent_schema_handling(self, temp_dir, sample_wds_simple): + """Test that from_webdataset correctly creates a dataset from a WebDataset with consistent JSON schemas and one merge key.""" + dataset = Dataset.from_webdataset( + name="test_dataset", + file_uri=sample_wds_simple, + metadata_uri=temp_dir, + merge_keys="filename", + ) + + # Verify schema fields + assert len(dataset.fields) == 6 + assert "label" in dataset.fields + assert "width" in dataset.fields + assert "height" in dataset.fields + assert "filename" in dataset.fields + assert "extra" in dataset.fields + assert "media_binary" in dataset.fields + + assert dataset.fields["filename"].is_merge_key + + # Verify datatypes inferred correctly + assert dataset.fields["label"].datatype == Datatype.int64() + assert dataset.fields["width"].datatype == Datatype.int64() + assert dataset.fields["height"].datatype == Datatype.int64() + 
assert dataset.fields["filename"].datatype == Datatype.string() + assert dataset.fields["extra"].datatype == Datatype.int64() + + # Verify data can be read + records = list(dataset.scan().to_pydict()) + + # Check first record + first_record = records[0] + assert first_record["label"] == 1 + assert first_record["width"] == 500 + assert first_record["height"] == 429 + assert first_record["filename"] == "n01443537/n01443537_14753.TXT" + assert "media_binary" in first_record + assert isinstance(first_record["media_binary"], bytes) + assert len(first_record["media_binary"]) > 0 + + # Verify all fields are Field objects + for _, field in dataset.fields: + assert isinstance(field, Field) + assert hasattr(field, "name") + assert hasattr(field, "datatype") + assert hasattr(field, "is_merge_key") + + # Verify media_binary field exists + assert "media_binary" in dataset.fields + assert dataset.fields["media_binary"].datatype == Datatype.binary("binary") + + def test_inconsistent_schema_handling(self, temp_dir, sample_wds_inconsistent): + """Test that from_webdataset correctly handles inconsistent JSON schemas.""" + dataset = Dataset.from_webdataset( + name="test_dataset", + file_uri=sample_wds_inconsistent, + metadata_uri=temp_dir, + merge_keys="filename", + ) + + # Should include all fields from both schemas + assert len(dataset.fields) == 11 + assert "label" in dataset.fields + assert "width" in dataset.fields + assert "height" in dataset.fields + assert "filename" in dataset.fields + assert "extra" in dataset.fields + assert "extra_3" in dataset.fields + assert "extra_4" in dataset.fields + assert "diff_label" in dataset.fields + assert "diff_width" in dataset.fields + assert "diff_height" in dataset.fields + assert "media_binary" in dataset.fields + + records = list(dataset.scan().to_pydict()) + assert len(records) == 5 + + def test_multiple_merge_key_handling(self, temp_dir, sample_wds_simple): + """Test that specifying more than 1 merge key raises an error.""" + with 
pytest.raises(ValueError): + Dataset.from_webdataset( + name="test_multiple_merge_key_handling", + file_uri=sample_wds_simple, + metadata_uri=temp_dir, + merge_keys=["label", "filename"], + ) + + def test_conflicting_data_types(self, temp_dir, sample_wds_diff_data_types): + """Test that specifying different datatypes under the same column is handled properly for non-lossy promotions.""" + dataset = Dataset.from_webdataset( + name="test_conflicting_data_types", + file_uri=sample_wds_diff_data_types, + metadata_uri=temp_dir, + merge_keys=["filename"], + ) + assert "label" in dataset.fields + assert "width" in dataset.fields + assert "height" in dataset.fields + assert dataset.fields["label"].datatype == Datatype.int64() + assert dataset.fields["width"].datatype == Datatype.int64() + assert dataset.fields["height"].datatype == Datatype.float() + + def test_invalid_merge_key_handling(self, temp_dir, sample_wds_simple): + """Test that specifying a non-existent field as merge key raises an error.""" + with pytest.raises(ValueError): + Dataset.from_webdataset( + name="test_invalid_merge_key_handling", + file_uri=sample_wds_simple, + metadata_uri=temp_dir, + merge_keys="nonexistent_field", + ) + + def test_batch_reading_functionality(self, temp_dir, sample_wds_long): + """Test that batch reading works correctly with different batch sizes.""" + # Test with batch_size=1 + dataset1 = Dataset.from_webdataset( + name="test_batch1", + file_uri=sample_wds_long, + metadata_uri=temp_dir, + merge_keys="filename", + batch_size=1, + ) + + # Test with batch_size=2 + dataset2 = Dataset.from_webdataset( + name="test_batch2", + file_uri=sample_wds_long, + metadata_uri=temp_dir, + merge_keys="filename", + batch_size=2, + ) + + # Test with batch_size=3 + dataset3 = Dataset.from_webdataset( + name="test_batch3", + file_uri=sample_wds_long, + metadata_uri=temp_dir, + merge_keys="filename", + batch_size=3, + ) + + # Both should produce the same data + records1 = 
list(dataset1.scan().to_pydict()) + records2 = list(dataset2.scan().to_pydict()) + records3 = list(dataset3.scan().to_pydict()) + + assert len(records1) == len(records2) == len(records3) == 6 + assert records1 == records2 == records3 diff --git a/test_repeat.sh b/test_repeat.sh new file mode 100755 index 000000000..cd4daddd9 --- /dev/null +++ b/test_repeat.sh @@ -0,0 +1,40 @@ +#!/bin/bash + +# pytest deltacat/tests/compute/test_janitor.py -s + +# 1️⃣ Set how many times to repeat each test run +n=20 # <-- change this number as needed + +# 2️⃣ Counters for failures +fail_no_s=0 +fail_with_s=0 + +echo "Running WITHOUT -s (no print output)..." +for i in $(seq 1 $n); do + echo "Run $i/$n (WITHOUT -s)" + pytest deltacat/tests/storage/model/test_metafile_io.py::TestMetafileIO::test_txn_conflict_concurrent_multiprocess_table_create > run_no_s.log 2>&1 + if [ $? -ne 0 ]; then + fail_no_s=$((fail_no_s+1)) + echo "❌ Failed (WITHOUT -s)" + else + echo "✅ Passed" + fi +done + +echo "" +echo "Running WITH -s (print output enabled)..." +for i in $(seq 1 $n); do + echo "Run $i/$n (WITH -s)" + pytest -s deltacat/tests/storage/model/test_metafile_io.py::TestMetafileIO::test_txn_conflict_concurrent_multiprocess_table_create > run_with_s.log 2>&1 + if [ $? -ne 0 ]; then + fail_with_s=$((fail_with_s+1)) + echo "❌ Failed (WITH -s)" + else + echo "✅ Passed" + fi +done + +# 3️⃣ Final summary +echo "" +echo "results WITHOUT print output enabled: $fail_no_s/$n runs failed." +echo "results WITH print output enabled: $fail_with_s/$n runs failed."