-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcompute_results.py
More file actions
62 lines (52 loc) · 1.95 KB
/
compute_results.py
File metadata and controls
62 lines (52 loc) · 1.95 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import os
import mteb
import logging
import argparse
from mteb import MTEB
from mteb.models.cache_wrapper import CachedEmbeddingWrapper
from utils.models import CustomModel, CustomDatabricksModel, CustomGoogleModel, CustomRandomModel
from config.eval import RETRIEVAL_TASKS
from dotenv import load_dotenv
# Load variables from a .env file (if present) into the process environment.
# CACHE_PATH and MTEB_RESULTS_PATH are read from the environment further down.
load_dotenv()

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

# Command-line arguments
parser = argparse.ArgumentParser(
    prog='Text Embedding Analysis',
    description='Parallel processing of text embeddings for MTEB tasks.')
parser.add_argument('--model')
args = parser.parse_args()

# Fall back to a fixed model when --model is omitted (or empty).
if not args.model:
    print("No model name provided, launching debug mode with mistral-7b.")
    model_name = "intfloat/e5-mistral-7b-instruct"
else:
    model_name = args.model

# Resolve the output locations *before* loading the model so that a missing
# setting aborts immediately instead of silently writing to paths such as
# "None/cache_<model>" after an expensive model load.
cache_path = os.environ.get("CACHE_PATH")
results_path = os.environ.get("MTEB_RESULTS_PATH")
missing_vars = [
    var_name
    for var_name, value in (
        ("CACHE_PATH", cache_path),
        ("MTEB_RESULTS_PATH", results_path),
    )
    if value is None
]
if missing_vars:
    raise SystemExit(
        "Missing required environment variable(s): "
        + ", ".join(missing_vars)
        + ". Set them in the environment or in a .env file."
    )

# Select the model class from the model-name prefix; any other name is
# resolved through mteb.get_model.
if model_name.startswith("apollo"):
    model = CustomModel(model_name=model_name,
                        normalize_embeddings=True)
elif model_name.startswith("google"):
    model = CustomGoogleModel(model_name=model_name)
elif model_name.startswith("databricks"):
    model = CustomDatabricksModel(model_name=model_name)
elif model_name.startswith("random"):
    model = CustomRandomModel(model_name=model_name)
else:
    model = mteb.get_model(model_name)

# Run evaluation
evaluation = MTEB(tasks=RETRIEVAL_TASKS)

# "/" in a model name (e.g. "org/model") would add a directory level to the
# cache location, so it is flattened to "_" for the cache directory name.
cache_name = model_name.replace("/", "_")
model_with_cached_emb = CachedEmbeddingWrapper(
    model, cache_path=os.path.join(cache_path, f"cache_{cache_name}")
)

evaluation.run(
    model_with_cached_emb,
    save_predictions=True,
    overwrite_results=True,
    # NOTE(review): "model_name" is forwarded with the encode kwargs,
    # presumably for the custom model wrappers -- confirm it is accepted by
    # models returned from mteb.get_model.
    encode_kwargs={"batch_size": 1,
                   "normalize_embeddings": True,
                   "model_name": model_name},
    # os.path.join inserts the separator when MTEB_RESULTS_PATH has no
    # trailing slash; with a trailing slash the result is unchanged.
    output_folder=os.path.join(results_path, model_name),
    corpus_chunk_size=1000,
)
print(f"Completed {model_name}.")