
DagsHubCallback fails on Windows #621

@mattiacurri


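`dagshub.init()` sets `MLFLOW_TRACKING_URI` to a URL, but the DagsHub integration in transformers (`DagsHubCallback`) splits that URI with `os.sep` to extract the repo owner and name. On Windows, `os.sep` is a backslash, which never appears in a URL. The split therefore returns a single element, and indexing it with `[-2]` raises `IndexError`. On Linux there is no problem. I couldn't test on macOS, but its separator is the same as on Linux (`/`).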
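The failure can be reproduced without DagsHub or Windows. The snippet below is an illustrative sketch that applies the same indexing as the failing line to a tracking URI, passing the separator explicitly instead of reading `os.sep`. It assumes the URI has the form `https://dagshub.com/<owner>/<repo>.mlflow` (XXX/YYY are the placeholders from this report). `owner_and_name` is a hypothetical name, not a transformers function.

```python
# Illustrative only: mimic DagsHubCallback's owner/name extraction
# with an explicit separator instead of os.sep.
TRACKING_URI = "https://dagshub.com/XXX/YYY.mlflow"  # assumed URI shape


def owner_and_name(uri: str, sep: str) -> tuple[str, str]:
    # Same indexing as the failing line: parts[-2] is the owner,
    # parts[-1] up to the first "." is the repo name.
    parts = uri.split(sep)
    return parts[-2], parts[-1].split(".")[0]


# POSIX: os.sep == "/", which happens to match the URL separator.
print(owner_and_name(TRACKING_URI, "/"))  # ('XXX', 'YYY')

# Windows: os.sep == "\\", which never occurs in the URL, so split()
# returns a one-element list and parts[-2] raises IndexError.
try:
    owner_and_name(TRACKING_URI, "\\")
except IndexError as exc:
    print("IndexError:", exc)  # IndexError: list index out of range
```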

```python
import dagshub
from datasets import Dataset
import mlflow
from setfit import SetFitModel, Trainer, TrainingArguments

if __name__ == "__main__":
    dagshub.init(repo_owner="XXX", repo_name="YYY", mlflow=True)

    mlflow.set_experiment("issue_dagshub")

    train_data = Dataset.from_dict({
        "text": ["example 1", "example 2", "example 3"],
        "label": [[1, 0], [0, 1], [1, 1]]
    })

    model = SetFitModel.from_pretrained(
        "sentence-transformers/paraphrase-MiniLM-L3-v2",
        multi_target_strategy="multi-output"
    )

    with mlflow.start_run(run_name="minimal-test-setfit"):
        args = TrainingArguments(
            num_epochs=1,
            batch_size=2,
            num_iterations=1,
            report_to="mlflow"
        )

        trainer = Trainer(
            model=model,
            args=args,
            train_dataset=train_data,
            column_mapping={"text": "text", "label": "label"}
        )

        trainer.train()  # Issue occurs here (IndexError on Windows)
```
```
Accessing as XXX
Initialized MLflow to track repo "XXX/YYY"
Repository XXX/YYY initialized!
model_head.pkl not found on HuggingFace Hub, initialising classification head with random weights. You should TRAIN this model on a downstream task to use it for predictions and inference.
C:\Users\X\Desktop\Progetti\my-first-repo\.venv\Lib\site-packages\mlflow\store\tracking\rest_store.py:211: DeprecationWarning: label() is deprecated. Use is_required() or is_repeated() instead.
  req_body = message_to_json(
Applying column mapping to the training dataset
Map: 100%|████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 1367.56 examples/s]
***** Running training *****
  Num unique pairs = 6
  Batch size = 2
  Num epochs = 1
🏃 View run minimal-test-setfit at: X
🧪 View experiment at: X
C:\Users\X\Desktop\Progetti\my-first-repo\.venv\Lib\site-packages\mlflow\store\tracking\rest_store.py:182: DeprecationWarning: label() is deprecated. Use is_required() or is_repeated() instead.
  req_body = message_to_json(
Traceback (most recent call last):
  File "C:\Users\X\Desktop\Progetti\my-first-repo\minimal_issue.py", line 37, in <module>
    trainer.train()  # Issue occurs here
    ^^^^^^^^^^^^^^^
  File "C:\Users\X\Desktop\Progetti\my-first-repo\.venv\Lib\site-packages\setfit\trainer.py", line 531, in train
    self.train_embeddings(*full_parameters, args=args)
  File "C:\Users\X\Desktop\Progetti\my-first-repo\.venv\Lib\site-packages\setfit\trainer.py", line 582, in train_embeddings
    self.st_trainer.train()
  File "C:\Users\X\Desktop\Progetti\my-first-repo\.venv\Lib\site-packages\transformers\trainer.py", line 2325, in train
    return inner_training_loop(
           ^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\X\Desktop\Progetti\my-first-repo\.venv\Lib\site-packages\transformers\trainer.py", line 2573, in _inner_training_loop
    self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\X\Desktop\Progetti\my-first-repo\.venv\Lib\site-packages\transformers\trainer_callback.py", line 506, in on_train_begin
    return self.call_event("on_train_begin", args, state, control)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\X\Desktop\Progetti\my-first-repo\.venv\Lib\site-packages\setfit\trainer.py", line 97, in <lambda>
    self.callback_handler.call_event = lambda *args, **kwargs: overwritten_call_event(
                                                               ^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\X\Desktop\Progetti\my-first-repo\.venv\Lib\site-packages\setfit\trainer.py", line 74, in overwritten_call_event
    result = getattr(callback, event)(
             ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\X\Desktop\Progetti\my-first-repo\.venv\Lib\site-packages\transformers\integrations\integration_utils.py", line 1489, in on_train_begin
    self.setup(args, state, model)
  File "C:\Users\X\Desktop\Progetti\my-first-repo\.venv\Lib\site-packages\transformers\integrations\integration_utils.py", line 1569, in setup
    owner=self.remote.split(os.sep)[-2],
          ~~~~~~~~~~~~~~~~~~~~~~~~~^^^^
IndexError: list index out of range
```

I'm using setfit==1.1.3.
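One way to make the parsing portable is to treat `MLFLOW_TRACKING_URI` as a URL rather than a filesystem path: parse it with `urllib.parse.urlparse` and split the path component on `/`, which is the URL separator on every OS. The helper below is a hypothetical sketch, not the upstream transformers code. Its name and the `.mlflow` suffix handling are assumptions based on a URI of the form `https://dagshub.com/<owner>/<repo>.mlflow`.

```python
from urllib.parse import urlparse


def parse_dagshub_remote(uri: str) -> tuple[str, str]:
    """Hypothetical helper: extract (owner, repo) from a DagsHub MLflow URI.

    Assumes the form https://dagshub.com/<owner>/<repo>.mlflow.
    """
    # URL paths always use "/", regardless of os.sep; drop empty segments
    # produced by the leading slash or a trailing slash.
    segments = [s for s in urlparse(uri).path.split("/") if s]
    if len(segments) < 2:
        raise ValueError(f"Unexpected MLFLOW_TRACKING_URI format: {uri!r}")
    owner, repo = segments[-2], segments[-1]
    # Strip only the ".mlflow" suffix so repo names containing dots survive.
    return owner, repo.removesuffix(".mlflow")


print(parse_dagshub_remote("https://dagshub.com/XXX/YYY.mlflow"))  # ('XXX', 'YYY')
```

Compared with `split(os.sep)`, this behaves identically on Windows and POSIX and fails with a descriptive `ValueError` instead of a bare `IndexError` when the URI has an unexpected shape. Using `removesuffix` (Python 3.9+) instead of `split(".")[0]` is a deliberate choice so that a repo name such as `my.repo` is not truncated.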
