Open
Description
dagshub.init() sets MLFLOW_TRACKING_URI to a URL, so when the DagsHub integration splits that value on os.sep it breaks on Windows: os.sep is a backslash there, while the URL contains only forward slashes. On Linux there is no problem. I can't test on macOS, but the separator there should be the same as on Linux.
import dagshub
from datasets import Dataset
import mlflow
from setfit import SetFitModel, Trainer, TrainingArguments

if __name__ == "__main__":
    dagshub.init(repo_owner="XXX", repo_name="YYY", mlflow=True)
    mlflow.set_experiment("issue_dagshub")

    train_data = Dataset.from_dict({
        "text": ["example 1", "example 2", "example 3"],
        "label": [[1, 0], [0, 1], [1, 1]]
    })

    model = SetFitModel.from_pretrained(
        "sentence-transformers/paraphrase-MiniLM-L3-v2",
        multi_target_strategy="multi-output"
    )

    with mlflow.start_run(run_name="minimal-test-setfit"):
        args = TrainingArguments(
            num_epochs=1,
            batch_size=2,
            num_iterations=1,
            report_to="mlflow"
        )
        trainer = Trainer(
            model=model,
            args=args,
            train_dataset=train_data,
            column_mapping={"text": "text", "label": "label"}
        )
        trainer.train()

Accessing as XXX
Initialized MLflow to track repo "XXX/YYY"
Repository XXX/YYY initialized!
model_head.pkl not found on HuggingFace Hub, initialising classification head with random weights. You should TRAIN this model on a downstream task to use it for predictions and inference.
C:\Users\X\Desktop\Progetti\my-first-repo\.venv\Lib\site-packages\mlflow\store\tracking\rest_store.py:211: DeprecationWarning: label() is deprecated. Use is_required() or is_repeated() instead.
req_body = message_to_json(
Applying column mapping to the training dataset
Map: 100%|████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 1367.56 examples/s]
***** Running training *****
Num unique pairs = 6
Batch size = 2
Num epochs = 1
🏃 View run minimal-test-setfit at: X
🧪 View experiment at: X
C:\Users\X\Desktop\Progetti\my-first-repo\.venv\Lib\site-packages\mlflow\store\tracking\rest_store.py:182: DeprecationWarning: label() is deprecated. Use is_required() or is_repeated() instead.
req_body = message_to_json(
Traceback (most recent call last):
  File "C:\Users\X\Desktop\Progetti\my-first-repo\minimal_issue.py", line 37, in <module>
    trainer.train() # Issue occurs here
    ^^^^^^^^^^^^^^^
  File "C:\Users\X\Desktop\Progetti\my-first-repo\.venv\Lib\site-packages\setfit\trainer.py", line 531, in train
    self.train_embeddings(*full_parameters, args=args)
  File "C:\Users\X\Desktop\Progetti\my-first-repo\.venv\Lib\site-packages\setfit\trainer.py", line 582, in train_embeddings
    self.st_trainer.train()
  File "C:\Users\X\Desktop\Progetti\my-first-repo\.venv\Lib\site-packages\transformers\trainer.py", line 2325, in train
    return inner_training_loop(
           ^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\X\Desktop\Progetti\my-first-repo\.venv\Lib\site-packages\transformers\trainer.py", line 2573, in _inner_training_loop
    self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\X\Desktop\Progetti\my-first-repo\.venv\Lib\site-packages\transformers\trainer_callback.py", line 506, in on_train_begin
    return self.call_event("on_train_begin", args, state, control)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\X\Desktop\Progetti\my-first-repo\.venv\Lib\site-packages\setfit\trainer.py", line 97, in <lambda>
    self.callback_handler.call_event = lambda *args, **kwargs: overwritten_call_event(
                                                               ^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\X\Desktop\Progetti\my-first-repo\.venv\Lib\site-packages\setfit\trainer.py", line 74, in overwritten_call_event
    result = getattr(callback, event)(
             ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\X\Desktop\Progetti\my-first-repo\.venv\Lib\site-packages\transformers\integrations\integration_utils.py", line 1489, in on_train_begin
    self.setup(args, state, model)
  File "C:\Users\X\Desktop\Progetti\my-first-repo\.venv\Lib\site-packages\transformers\integrations\integration_utils.py", line 1569, in setup
    owner=self.remote.split(os.sep)[-2],
          ~~~~~~~~~~~~~~~~~~~~~~~~~^^^^
IndexError: list index out of range
I'm using setfit==1.1.3.
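
To illustrate what I think is going on, here is a minimal sketch that reproduces the IndexError on any OS by passing the separator explicitly, and shows one possible way to parse the URL instead. The URI value and both helper names are my own assumptions for the example; they are not taken from the transformers code, and I am guessing that the tracking URI ends in a ".mlflow" suffix.

```python
from urllib.parse import urlsplit

# Assumed shape of the tracking URI set by dagshub.init() (placeholder values).
URI = "https://dagshub.com/XXX/YYY.mlflow"


def owner_and_repo_by_sep(remote, sep):
    """Hypothetical helper mimicking remote.split(os.sep)[-2] with an explicit separator."""
    parts = remote.split(sep)
    return parts[-2], parts[-1].split(".")[0]


def owner_and_repo_from_url(remote):
    """Hypothetical alternative: URLs always use '/', so parse the path OS-independently."""
    path_parts = urlsplit(remote).path.strip("/").split("/")
    owner, name = path_parts[-2], path_parts[-1]
    return owner, name.split(".")[0]


if __name__ == "__main__":
    # With "/" (os.sep on Linux/macOS) the split finds the owner and repo name.
    print(owner_and_repo_by_sep(URI, "/"))

    # With "\\" (os.sep on Windows) the URL is never split, so [-2] is out of range.
    try:
        owner_and_repo_by_sep(URI, "\\")
    except IndexError as exc:
        print("IndexError:", exc)

    # Parsing the URL path does not depend on the OS separator.
    print(owner_and_repo_from_url(URI))
```

Splitting on a literal "/" would probably also be enough; urlsplit just makes it explicit that the value is a URL and not a filesystem path.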