
[BUG] Pipeline serialization/caching issue when including RoutingBatchFunction #1070

@liamcripwell

Description


Describe the bug
I am experiencing an issue when trying to use a RoutingBatchFunction inside a pipeline. Specifically, I am using sample_n_steps() as shown in the example here: https://distilabel.argilla.io/latest/api/pipeline/routing_batch_function/?h=routing#distilabel.pipeline.routing_batch_function.routing_batch_function

The pipeline runs without issue the first time, but if I try to run it again it raises AttributeError: 'NoneType' object has no attribute 'name', originating from distilabel/src/distilabel/pipeline/routing_batch_function.py:170 in dump (on develop). This appears to be a failure when serializing this step.
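My guess at what is going on, for reference (the attribute names below are assumptions based on the traceback, not copied from the actual source): dump() seems to read the name of the step the routing batch function is bound to, and on the second run that reference is still unset, so a pattern roughly like this fails:

# Hypothetical illustration of the failure mode, not the actual distilabel code.
class RoutingBatchFunctionSketch:
    def __init__(self):
        self._step = None  # would normally be set when bound to an upstream step

    def dump(self) -> dict:
        # Raises AttributeError: 'NoneType' object has no attribute 'name'
        return {"step": self._step.name}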

We are able to work around this temporarily by manually deleting the cache directory on disk, but the error keeps occurring even when passing use_cache=False to pipeline.run(). Is some degree of caching still required even when this is specified?
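For completeness, this is the manual workaround we are using between runs. The cache location below is just where the cache ends up on my machine, so treat the path as an assumption:

# Manual workaround: wipe the distilabel pipeline cache before re-running.
# NOTE: the cache path is an assumption based on my local setup; adjust if
# your installation caches elsewhere.
import shutil
from pathlib import Path

cache_dir = Path.home() / ".cache" / "distilabel" / "pipelines"
if cache_dir.exists():
    shutil.rmtree(cache_dir)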

To Reproduce
Code to reproduce

from distilabel.llms import OpenAILLM
from distilabel.pipeline import Pipeline, sample_n_steps
from distilabel.steps import LoadDataFromHub, GroupColumns
from distilabel.steps.tasks import TextGeneration

random_routing_batch = sample_n_steps(2)


with Pipeline(name="routing-batch-function") as pipeline:
    load_dataset = LoadDataFromHub(
        name="load_dataset",
    )

    generations = []
    for llm in (
        OpenAILLM(model="gpt-4o"),
        OpenAILLM(model="gpt-4o-mini"),
    ):
        task = TextGeneration(
            name=f"text_generation_with_{llm.model_name}",
            llm=llm,
            input_mappings={"instruction": "prompt"},
        )
        generations.append(task)

    combine_columns = GroupColumns(columns=["generation", "model_name"])

    load_dataset >> random_routing_batch >> generations >> combine_columns


if __name__ == "__main__":
    distiset = pipeline.run(
        use_cache=False,
        parameters={
            "load_dataset": {
                "repo_id": "distilabel-internal-testing/instruction-dataset-mini",
                "num_examples": 3,
                "split": "test",
            },
        }
    )

Expected behaviour
For the serialization to handle this case, and/or for caching to actually be skipped when use_cache=False is specified. Perhaps I am missing a recommended best practice?

Desktop (please complete the following information):

  • Package version: built from source; the issue occurs on both the main and develop branches
  • Python version: 3.11.9
