Skip to content

Commit 9734186

Browse files
committed
GCS base_model from pretraining
1 parent 4ac824e commit 9734186

2 files changed

Lines changed: 47 additions & 2 deletions

File tree

src/grouping_trainer/utils.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1+
import logging
12
import math
3+
import os
4+
import subprocess
25
from collections.abc import Iterable
36
from typing import Literal, cast, overload
47

@@ -9,6 +12,11 @@
912
from sentence_transformers import SentenceTransformer as SentenceTransformerOriginal
1013
from transformers import PreTrainedTokenizerBase
1114

15+
logger = logging.getLogger(__name__)
16+
17+
_GCS_PREFIX = "gs://"
18+
_DIR_BASE_MODELS_LOCAL = "_base_models"
19+
1220

1321
def concat_vertical_unordered(
1422
dfs: Iterable[pl.DataFrame],
@@ -122,13 +130,45 @@ def encode(self, sentences: str | list[str] | np.ndarray, **kwargs) -> np.ndarra
122130
return embeddings[[text_to_idx[text] for text in texts]] # assume numpy or torch
123131

124132

133+
def is_gcs_uri(uri: str) -> bool:
    """
    Return True when `uri` begins with the GCS scheme prefix (`_GCS_PREFIX`), False otherwise.
    """
    has_gcs_scheme = uri.startswith(_GCS_PREFIX)
    return has_gcs_scheme
135+
136+
137+
def assert_gcs_path_exists(uri: str) -> None:
    """
    Check that `uri` exists in GCS by listing it as a directory via `gcloud storage ls`.

    The URI is normalized to end in exactly one `/` before listing. The listing itself is discarded (stdout goes to
    DEVNULL); only the exit status matters.

    Raises `CalledProcessError` if `uri` doesn't exist (or matches no objects) in GCS.
    """
    # Strip any trailing slashes, then add exactly one so the path is listed as a directory.
    uri_as_dir = uri.rstrip("/") + "/"
    cmd = ["gcloud", "storage", "ls", uri_as_dir]
    subprocess.run(cmd, stdout=subprocess.DEVNULL, check=True)
146+
147+
148+
def _download_base_model_from_gcs(uri: str) -> str:
    """
    Rsync `uri` (a `gs://...` directory) into `_base_models/<basename>/` relative to CWD and return the local path.

    `<basename>` is the last path component of `uri` once trailing slashes are stripped. The destination directory
    (including the `_base_models/` parent) is created first if it is missing.

    Raises `subprocess.CalledProcessError` if the `gcloud storage rsync` command exits with a non-zero status.
    """
    # Normalize once so the basename and the rsync source are derived from the same string.
    uri_src = uri.rstrip("/")
    basename = uri_src.rsplit("/", 1)[-1]
    path_local = os.path.join(_DIR_BASE_MODELS_LOCAL, basename)
    # Create the destination up front rather than relying on the CLI to create nested local directories on a fresh
    # checkout. `exist_ok=True` keeps repeat calls (re-syncs) working.
    os.makedirs(path_local, exist_ok=True)
    logger.info(f"Downloading base model: {uri} -> {path_local}")
    subprocess.run(["gcloud", "storage", "rsync", "-r", uri_src, path_local], check=True)
    return path_local
157+
158+
125159
def encoder_from_base(base_model: str, use_text_prefix: bool = True) -> SentenceTransformer:
126160
"""
127161
Build a SentenceTransformer encoder with standard dtype/attention settings.
128162
163+
`base_model` is a HuggingFace model ID, a local path, or a `gs://...` path to a custom model directory. gs:// models
164+
are downloaded into `_base_models/` (relative to CWD) on first call.
165+
129166
Handles model-specific quirks (e.g. jina v5's config_kwargs and trust_remote_code) and enables bfloat16 + SDPA when
130167
supported.
131168
"""
169+
if is_gcs_uri(base_model):
170+
base_model = _download_base_model_from_gcs(base_model)
171+
132172
if base_model == "jinaai/jina-embeddings-v5-text-nano-text-matching":
133173
return SentenceTransformer(
134174
base_model,

train.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,9 @@ def run(
4949
Parameters
5050
----------
5151
base_model
52-
HuggingFace model ID or local path for the base encoder. Others we've tried: Alibaba-NLP/gte-modernbert-base,
53-
Qwen/Qwen3-Embedding-0.6B, jinaai/jina-embeddings-v5-text-nano-text-matching
52+
HuggingFace model ID or local path for the base encoder, or a `gs://...` path to a custom model directory in our
53+
bucket (downloaded into `_base_models/` on the instance). Others we've tried:
54+
Alibaba-NLP/gte-modernbert-base, Qwen/Qwen3-Embedding-0.6B, jinaai/jina-embeddings-v5-text-nano-text-matching
5455
run_shortname
5556
Short name for the run. Doesn't need to be unique b/c it's appended to the timestamp.
5657
per_device_token_budget_scale
@@ -79,6 +80,10 @@ def run(
7980
if not tiny_run:
8081
assert run_shortname is not None, "run_shortname is required for full training runs"
8182

83+
# Fail fast on a typo'd gs:// model URI before wasting time launching training.
84+
if gt.utils.is_gcs_uri(base_model):
85+
gt.utils.assert_gcs_path_exists(base_model)
86+
8287
# Generate run_name up front so we can log the artifact URL locally before auto-launching. On the remote, re-use the
8388
# local run_name via env var so both sides log the same GCS path (rather than each generating its own timestamp).
8489
run_name = os.environ.get(_RUN_NAME_ENV_VAR) or gt.launch.run_name_from_shortname(run_shortname or "tiny-run")

0 commit comments

Comments
 (0)