|
| 1 | +import logging |
1 | 2 | import math |
| 3 | +import os |
| 4 | +import subprocess |
2 | 5 | from collections.abc import Iterable |
3 | 6 | from typing import Literal, cast, overload |
4 | 7 |
|
|
9 | 12 | from sentence_transformers import SentenceTransformer as SentenceTransformerOriginal |
10 | 13 | from transformers import PreTrainedTokenizerBase |
11 | 14 |
|
| 15 | +logger = logging.getLogger(__name__) |
| 16 | + |
| 17 | +_GCS_PREFIX = "gs://" |
| 18 | +_DIR_BASE_MODELS_LOCAL = "_base_models" |
| 19 | + |
12 | 20 |
|
13 | 21 | def concat_vertical_unordered( |
14 | 22 | dfs: Iterable[pl.DataFrame], |
@@ -122,13 +130,45 @@ def encode(self, sentences: str | list[str] | np.ndarray, **kwargs) -> np.ndarra |
122 | 130 | return embeddings[[text_to_idx[text] for text in texts]] # assume numpy or torch |
123 | 131 |
|
124 | 132 |
|
| 133 | +def is_gcs_uri(uri: str) -> bool: |
| 134 | + return uri.startswith(_GCS_PREFIX) |
| 135 | + |
| 136 | + |
| 137 | +def assert_gcs_path_exists(uri: str) -> None: |
| 138 | + """ |
| 139 | + Raises `CalledProcessError` if `uri` doesn't exist (or matches no objects) in GCS. |
| 140 | + """ |
| 141 | + subprocess.run( |
| 142 | + ["gcloud", "storage", "ls", uri.rstrip("/") + "/"], |
| 143 | + check=True, |
| 144 | + stdout=subprocess.DEVNULL, |
| 145 | + ) |
| 146 | + |
| 147 | + |
| 148 | +def _download_base_model_from_gcs(uri: str) -> str: |
| 149 | + """ |
| 150 | + Rsync `uri` (a `gs://...` directory) into `_base_models/<basename>/` relative to CWD and return the local path. |
| 151 | + """ |
| 152 | + basename = uri.rstrip("/").rsplit("/", 1)[-1] |
| 153 | + path_local = os.path.join(_DIR_BASE_MODELS_LOCAL, basename) |
| 154 | + logger.info(f"Downloading base model: {uri} -> {path_local}") |
| 155 | + subprocess.run(["gcloud", "storage", "rsync", "-r", uri.rstrip("/"), path_local], check=True) |
| 156 | + return path_local |
| 157 | + |
| 158 | + |
125 | 159 | def encoder_from_base(base_model: str, use_text_prefix: bool = True) -> SentenceTransformer: |
126 | 160 | """ |
127 | 161 | Build a SentenceTransformer encoder with standard dtype/attention settings. |
128 | 162 |
|
| 163 | + `base_model` is a HuggingFace model ID, a local path, or a `gs://...` path to a custom model directory. gs:// models |
| 164 | + are downloaded into `_base_models/` (relative to CWD) on first call. |
| 165 | +
|
129 | 166 | Handles model-specific quirks (e.g. jina v5's config_kwargs and trust_remote_code) and enables bfloat16 + SDPA when |
130 | 167 | supported. |
131 | 168 | """ |
| 169 | + if is_gcs_uri(base_model): |
| 170 | + base_model = _download_base_model_from_gcs(base_model) |
| 171 | + |
132 | 172 | if base_model == "jinaai/jina-embeddings-v5-text-nano-text-matching": |
133 | 173 | return SentenceTransformer( |
134 | 174 | base_model, |
|
0 commit comments