Commit 6c794b7

Lazy load vectara hhem model because it is gated in HF (#1946)
* Temporary(?) remove vectara hhem model because it became restricted.

* Revert "Temporary(?) remove vectara hhem model because it became restricted."
  This reverts commit 3c9ad94.

* Moved model loading in HHEM-based metrics from prepare to compute, because the model is gated (to allow catalog prep).

* Added model to HHEM metric class.

* Disabled HHEM metric test.

Signed-off-by: Yoav Katz <[email protected]>
1 parent 92367e6 commit 6c794b7

File tree

2 files changed: +16 -11 lines

  prepare/metrics/hhem.py (+7 -8)
  src/unitxt/metrics.py (+9 -3)

2 files changed

+16
-11
lines changed

prepare/metrics/hhem.py

Lines changed: 7 additions & 8 deletions
@@ -1,6 +1,5 @@
 from unitxt import add_to_catalog
 from unitxt.metrics import FaithfulnessHHEM
-from unitxt.test_utils.metrics import test_metric
 
 pairs = [
     ("The capital of France is Berlin.", "The capital of France is Paris."),
@@ -29,11 +28,11 @@
 
 references = [[p[0]] for p in pairs]
 metric = FaithfulnessHHEM()
-outputs = test_metric(
-    metric=metric,
-    predictions=predictions,
-    references=references,
-    instance_targets=instance_targets,
-    global_target=global_target,
-)
+# outputs = test_metric(
+#     metric=metric,
+#     predictions=predictions,
+#     references=references,
+#     instance_targets=instance_targets,
+#     global_target=global_target,
+# )
 add_to_catalog(metric, "metrics.vectara_groundedness_hhem_2_1", overwrite=True)
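
Note that the test_metric call is commented out rather than deleted: running it requires access to the gated HHEM checkpoint on Hugging Face. A minimal sketch of how one might re-enable it locally (the login call is the standard huggingface_hub API; the token value is a placeholder):

    # Hypothetical local setup before uncommenting the test_metric block above.
    # Requires a Hugging Face account that has been granted access to the
    # gated HHEM model repository.
    from huggingface_hub import login

    login(token="hf_...")  # placeholder; alternatively run `huggingface-cli login`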

src/unitxt/metrics.py

Lines changed: 9 additions & 3 deletions
@@ -5258,12 +5258,11 @@ class FaithfulnessHHEM(BulkInstanceMetric):
     # single_reference_per_prediction = True
     max_context_words = 4096
     reduction_map = {"mean": [main_score]}
+    model = None
 
     _requirements_list: List[str] = ["transformers", "torch"]
 
-    @retry_connection_with_exponential_backoff(backoff_factor=2)
-    def prepare(self):
-        super().prepare()
+    def load_model(self):
         import torch
 
         if torch.cuda.is_available():
@@ -5281,6 +5280,11 @@ def prepare(self):
             model_path, trust_remote_code=True
         ).to(device)
 
+    @retry_connection_with_exponential_backoff(backoff_factor=2)
+    def prepare(self):
+        super().prepare()
+        # load_model() moved from prepare() to compute() because model is gated in HF
+
     def compute(
         self,
         references: List[List[Any]],
@@ -5289,6 +5293,8 @@
     ) -> List[Dict[str, Any]]:
         from tqdm import tqdm
 
+        if self.model is None:
+            self.load_model()
         # treat the references as the contexts and the predictions as answers
         # concat references

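Taken together, the two hunks implement a simple lazy-loading pattern: the model handle lives on the class as None, prepare() stays network-free so catalog preparation succeeds without HF credentials, and the first compute() call triggers the actual download. A standalone sketch of the pattern (GatedModelMetric and MODEL_PATH are hypothetical stand-ins, not unitxt names; the from_pretrained call mirrors the diff, but the model class is an assumption):

    from typing import Any, List

    MODEL_PATH = "..."  # placeholder for the gated HF repo id

    class GatedModelMetric:
        model = None  # stays None until the first compute() call

        def load_model(self):
            import torch
            from transformers import AutoModelForSequenceClassification

            device = "cuda" if torch.cuda.is_available() else "cpu"
            self.model = AutoModelForSequenceClassification.from_pretrained(
                MODEL_PATH, trust_remote_code=True
            ).to(device)

        def prepare(self):
            # Deliberately no network access: catalog prep must succeed even
            # when the user has no credentials for the gated repo.
            pass

        def compute(self, references: List[List[Any]], predictions: List[Any]):
            if self.model is None:
                self.load_model()  # first call pays the one-time download cost
            # ... score each prediction against its references with self.model ...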