Description
Describe the bug
I am experiencing an issue when using a `RoutingBatchFunction` inside a pipeline. Specifically, I am using `sample_n_steps()`, following the example here: https://distilabel.argilla.io/latest/api/pipeline/routing_batch_function/?h=routing#distilabel.pipeline.routing_batch_function.routing_batch_function
The pipeline runs without issue the first time, but running it again raises `AttributeError: 'NoneType' object has no attribute 'name'`, originating in `dump` at distilabel/src/distilabel/pipeline/routing_batch_function.py:170 (on develop). This appears to be a failure when serializing the routing batch function.
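The error message means that something inside `dump` is `None` on the second run and `.name` is read from it anyway. The toy sketch below illustrates that failure mode and a `None`-guarded alternative. `ToyStep` and `ToyRoutingFunction` are hypothetical stand-ins, not distilabel's actual classes, and the sketch makes no claim about which attribute is `None` in the real code.

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass
class ToyStep:
    """Hypothetical stand-in for a pipeline step; only its name matters here."""
    name: str


@dataclass
class ToyRoutingFunction:
    """Hypothetical stand-in for a routing function that remembers the step
    it is attached to. This is NOT distilabel's implementation."""
    n: int
    step: Optional[ToyStep] = None

    def dump_unsafe(self) -> Dict[str, Any]:
        # Reads self.step.name unconditionally, so it raises
        # AttributeError: 'NoneType' object has no attribute 'name'
        # whenever no step is attached.
        return {"n": self.n, "step": self.step.name}

    def dump_safe(self) -> Dict[str, Any]:
        # Guarded variant: a missing step is serialized as None instead of crashing.
        step_name = self.step.name if self.step is not None else None
        return {"n": self.n, "step": step_name}


if __name__ == "__main__":
    routing = ToyRoutingFunction(n=2)
    try:
        routing.dump_unsafe()
    except AttributeError as exc:
        print(exc)  # 'NoneType' object has no attribute 'name'
    print(routing.dump_safe())  # {'n': 2, 'step': None}
```

Whether a guard like this is the right fix, or whether the attribute should never be `None` at that point, depends on distilabel internals not shown in this report.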
We can work around this temporarily by manually deleting the cache directory on disk, but the error still occurs when passing `use_cache=False` to `pipeline.run()`. Is caching supposed to be required to some degree even when this is specified?
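The manual workaround can be scripted so it runs before each pipeline run. This is a minimal sketch: `clear_pipeline_cache` is a hypothetical helper, and the caller must supply the cache directory path, since this report does not state where the pipeline's cache lives.

```python
import shutil
from pathlib import Path
from typing import Union


def clear_pipeline_cache(cache_dir: Union[str, Path]) -> bool:
    """Hypothetical helper: delete a pipeline cache directory before re-running.

    Returns True if a directory was removed, False if nothing was there.
    """
    path = Path(cache_dir).expanduser()
    if not path.exists():
        return False  # nothing cached yet, nothing to delete
    if not path.is_dir():
        raise NotADirectoryError(f"{path} is not a directory")
    shutil.rmtree(path)  # recursively remove all cached state
    return True


if __name__ == "__main__":
    import tempfile

    # Demo on a throwaway directory instead of a real cache location.
    demo = Path(tempfile.mkdtemp()) / "cache"
    demo.mkdir()
    (demo / "dummy.txt").write_text("placeholder")
    print(clear_pipeline_cache(demo))  # True
    print(clear_pipeline_cache(demo))  # False
```

This only automates the workaround; it does not explain why `use_cache=False` still reads the cached state.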
To Reproduce
Code to reproduce
```python
from distilabel.llms import OpenAILLM
from distilabel.pipeline import Pipeline, sample_n_steps
from distilabel.steps import LoadDataFromHub, GroupColumns
from distilabel.steps.tasks import TextGeneration

# Route each batch to 2 randomly sampled downstream steps
random_routing_batch = sample_n_steps(2)

with Pipeline(name="routing-batch-function") as pipeline:
    load_dataset = LoadDataFromHub(
        name="load_dataset",
    )

    # One TextGeneration task per model
    generations = []
    for llm in (
        OpenAILLM(model="gpt-4o"),
        OpenAILLM(model="gpt-4o-mini"),
    ):
        task = TextGeneration(
            name=f"text_generation_with_{llm.model_name}",
            llm=llm,
            input_mappings={"instruction": "prompt"},
        )
        generations.append(task)

    combine_columns = GroupColumns(columns=["generation", "model_name"])

    load_dataset >> random_routing_batch >> generations >> combine_columns

if __name__ == "__main__":
    distiset = pipeline.run(
        use_cache=False,  # the error still occurs on a second run with this set
        parameters={
            "load_dataset": {
                "repo_id": "distilabel-internal-testing/instruction-dataset-mini",
                "num_examples": 3,
                "split": "test",
            },
        },
    )
```

Expected behaviour
I expected serialization to handle this case, and/or caching to actually be skipped when `use_cache=False` is specified. Perhaps I am missing a best practice here?
Desktop (please complete the following information):
- Package version: built from source, happening on both main and develop branches
- Python version: 3.11.9