Skip to content

Commit 692e18b

Browse files
author
Clara Luise Pohland
committed
add joblib library
1 parent c7810dd commit 692e18b

File tree

4 files changed

+14
-6
lines changed

4 files changed

+14
-6
lines changed

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@
8989
"mergekit": ["mergekit>=0.0.5.1"],
9090
"peft": ["peft>=0.8.0"],
9191
"quantization": ["bitsandbytes"],
92-
"scikit": ["scikit-learn"],
92+
"scikit": ["scikit-learn", "joblib"],
9393
"test": ["parameterized", "pytest-cov", "pytest-rerunfailures", "pytest-xdist", "pytest"],
9494
"vllm": ["vllm>=0.7.1; sys_platform != 'win32'"], # vllm is not available on Windows
9595
"vlm": ["Pillow"],

tests/testing_utils.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from transformers import is_bitsandbytes_available, is_comet_available, is_sklearn_available, is_wandb_available
1919

2020
from trl import BaseBinaryJudge, BasePairwiseJudge, is_diffusers_available, is_llm_blender_available
21-
from trl.import_utils import is_mergekit_available
21+
from trl.import_utils import is_joblib_available, is_mergekit_available
2222

2323

2424
# transformers.testing_utils contains a require_bitsandbytes function, but relies on pytest markers which we don't use
@@ -62,7 +62,7 @@ def require_sklearn(test_case):
6262
"""
6363
Decorator marking a test that requires sklearn. Skips the test if sklearn is not available.
6464
"""
65-
return unittest.skipUnless(is_sklearn_available(), "test requires sklearn")(test_case)
65+
return unittest.skipUnless(is_sklearn_available() and is_joblib_available(), "test requires sklearn")(test_case)
6666

6767

6868
def require_comet(test_case):

trl/import_utils.py

+5
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
_rich_available = _is_package_available("rich")
3030
_unsloth_available = _is_package_available("unsloth")
3131
_vllm_available = _is_package_available("vllm")
32+
_joblib_available = _is_package_available("joblib")
3233

3334

3435
def is_deepspeed_available() -> bool:
@@ -59,6 +60,10 @@ def is_vllm_available() -> bool:
5960
return _vllm_available
6061

6162

63+
def is_joblib_available() -> bool:
64+
return _joblib_available
65+
66+
6267
class _LazyModule(ModuleType):
6368
"""
6469
Module class that surfaces all objects but only performs associated imports when the objects are requested.

trl/trainer/bco_trainer.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
from operator import itemgetter
2424
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union
2525

26-
import joblib
2726
import numpy as np
2827
import pandas as pd
2928
import torch
@@ -56,6 +55,7 @@
5655
from transformers.utils import is_peft_available
5756

5857
from ..data_utils import maybe_apply_chat_template
58+
from ..import_utils import is_joblib_available
5959
from ..models import PreTrainedModelWrapper, create_reference_model
6060
from .bco_config import BCOConfig
6161
from .utils import (
@@ -80,6 +80,9 @@
8080
if is_sklearn_available():
8181
from sklearn.linear_model import LogisticRegression
8282

83+
if is_joblib_available():
84+
import joblib
85+
8386
if is_deepspeed_available():
8487
import deepspeed
8588

@@ -350,9 +353,9 @@ def __init__(
350353
embedding_func: Optional[Callable] = None,
351354
embedding_tokenizer: Optional[PreTrainedTokenizerBase] = None,
352355
):
353-
if not is_sklearn_available():
356+
if embedding_func != None and not (is_sklearn_available() and is_joblib_available()):
354357
raise ImportError(
355-
"BCOTrainer requires the scikit-learn library. Please install it with `pip install scikit-learn`."
358+
"BCOTrainer with UDM requires the scikit-learn and joblib libraries. Please install it with `pip install scikit-learn joblib`."
356359
)
357360

358361
if type(args) is TrainingArguments:

0 commit comments

Comments
 (0)