-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathembeddings_utils.py
More file actions
83 lines (64 loc) · 2.61 KB
/
embeddings_utils.py
File metadata and controls
83 lines (64 loc) · 2.61 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
from __future__ import annotations
import os
from pathlib import Path
# Model identifier used by load_sentence_transformer() when the caller does
# not pass a model name explicitly.
DEFAULT_EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
def _candidate_cache_roots() -> list[Path]:
    """Return the directories to search for locally stored embedding models.

    Roots named by environment variables come first, in this priority order:
    ``EMBEDDING_MODEL_PATH``, ``SENTENCE_TRANSFORMERS_HOME``, ``HF_HOME``,
    ``TRANSFORMERS_CACHE``. Unset or empty variables are skipped. Two
    conventional cache locations under the user's home directory are appended
    last.

    A leading ``~`` in an environment value is expanded, so that a value such
    as ``~/.cache/huggingface`` names a real directory and collapses with the
    equivalent home-based default instead of being listed twice.

    Returns:
        The candidate roots in priority order with duplicates removed (first
        occurrence wins). The paths are not checked for existence.
    """
    env_names = (
        "EMBEDDING_MODEL_PATH",
        "SENTENCE_TRANSFORMERS_HOME",
        "HF_HOME",
        "TRANSFORMERS_CACHE",
    )

    roots: list[Path] = []
    for env_name in env_names:
        value = os.getenv(env_name)
        if not value:
            continue
        root = Path(value)
        try:
            root = root.expanduser()
        except RuntimeError:
            # "~user" for an unknown user cannot be expanded; keep the literal
            # path rather than failing the whole lookup over one bad variable.
            pass
        roots.append(root)

    home = Path.home()
    roots += [
        home / ".cache" / "huggingface",
        home / ".cache" / "torch" / "sentence_transformers",
    ]

    # dict.fromkeys keeps insertion order and drops later duplicates.
    return list(dict.fromkeys(roots))
def resolve_local_embedding_model_path(model_name: str) -> Path | None:
    """Locate an on-disk copy of *model_name*.

    Lookup order:

    1. ``EMBEDDING_MODEL_PATH`` (with ``~`` expanded). If that path exists it
       is returned as-is, regardless of *model_name*.
    2. For each root returned by ``_candidate_cache_roots()``, in order:

       a. a flat directory ``<root>/<name>`` that contains ``modules.json``,
          where ``<name>`` is the part of *model_name* after the last ``/``
          or *model_name* with ``/`` replaced by ``_`` (the directory naming
          used by older sentence-transformers caches);
       b. the most recently modified directory under
          ``<hub>/models--<model_name with "/" -> "--">/snapshots`` that
          contains ``modules.json``, where ``<hub>`` is the root itself when
          it is named ``hub`` and ``<root>/hub`` otherwise.

    Args:
        model_name: Model identifier such as ``"org/name"`` or a bare name.

    Returns:
        The first matching directory, or ``None`` when nothing is found.
    """
    explicit_path = os.getenv("EMBEDDING_MODEL_PATH")
    if explicit_path:
        candidate = Path(explicit_path).expanduser()
        if candidate.exists():
            return candidate

    model_tail = model_name.split("/")[-1]
    legacy_dir_name = model_name.replace("/", "_")
    # For names without "/" both spellings coincide; probe each only once.
    flat_names = list(dict.fromkeys((model_tail, legacy_dir_name)))
    repo_dir_name = model_name.replace("/", "--")

    for root in _candidate_cache_roots():
        for flat_name in flat_names:
            direct_candidate = root / flat_name
            if (direct_candidate / "modules.json").exists():
                return direct_candidate

        hub_root = root if root.name == "hub" else root / "hub"
        snapshot_root = hub_root / f"models--{repo_dir_name}" / "snapshots"
        # is_dir() rather than exists(): iterating a stray regular file
        # would raise NotADirectoryError.
        if not snapshot_root.is_dir():
            continue

        usable_snapshots = [
            snapshot
            for snapshot in snapshot_root.iterdir()
            if snapshot.is_dir() and (snapshot / "modules.json").exists()
        ]
        if usable_snapshots:
            # Newest usable snapshot wins.
            return max(usable_snapshots, key=lambda snapshot: snapshot.stat().st_mtime)

    return None
def load_sentence_transformer(model_name: str = DEFAULT_EMBEDDING_MODEL):
    """Build a ``SentenceTransformer`` for *model_name* from local files only.

    If ``resolve_local_embedding_model_path`` finds a directory, the model is
    loaded from that directory. Otherwise the model is requested by name with
    ``local_files_only=True``.

    Args:
        model_name: Model identifier; defaults to ``DEFAULT_EMBEDDING_MODEL``.

    Returns:
        The constructed ``SentenceTransformer`` instance.

    Raises:
        RuntimeError: If the by-name attempt fails for any reason; the
            underlying exception is attached as ``__cause__``. Failures while
            loading from a resolved directory propagate unchanged.
    """
    # Imported here so that importing this module does not itself require
    # sentence_transformers.
    from sentence_transformers import SentenceTransformer

    resolved_dir = resolve_local_embedding_model_path(model_name)

    if resolved_dir is None:
        try:
            print("Loading embedding model from local cache if available...")
            model = SentenceTransformer(model_name, local_files_only=True)
        except Exception as load_error:
            message = (
                "Failed to load the embedding model from local files. "
                "Set EMBEDDING_MODEL_PATH to a local SentenceTransformer directory, "
                "or pre-download the model before running offline."
            )
            raise RuntimeError(message) from load_error
        return model

    print(f"Loading embedding model from local path: {resolved_dir}")
    return SentenceTransformer(str(resolved_dir), local_files_only=True)