Skip to content

Commit 37a314a

Browse files
committed
feat: add @huggingface decorator with pluggable auth
- Pluggable auth: HuggingFaceAuthProvider interface and EnvHuggingFaceAuthProvider (HF_TOKEN / HUGGING_FACE_HUB_TOKEN) - Config METAFLOW_HUGGINGFACE_AUTH_PROVIDER; register hf_auth_provider in plugin system - @huggingface step decorator: declare models/model_mapping; get token via auth provider; download via huggingface_hub; expose current.huggingface.models[key]
1 parent 9770c47 commit 37a314a

12 files changed

Lines changed: 526 additions & 0 deletions

File tree

metaflow/extension_support/plugins.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,7 @@ def resolve_plugins(category, path_only=False):
187187
"datastore": lambda x: x.TYPE,
188188
"dataclient": lambda x: x.TYPE,
189189
"secrets_provider": lambda x: x.TYPE,
190+
"hf_auth_provider": lambda x: x.TYPE,
190191
"gcp_client_provider": lambda x: x.name,
191192
"deployer_impl_provider": lambda x: x.TYPE,
192193
"azure_client_provider": lambda x: x.name,

metaflow/metaflow_config.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,11 @@
4444
DEFAULT_SECRETS_BACKEND_TYPE = from_conf("DEFAULT_SECRETS_BACKEND_TYPE")
4545
DEFAULT_SECRETS_ROLE = from_conf("DEFAULT_SECRETS_ROLE")
4646

47+
# HuggingFace @huggingface decorator: auth provider type (Part 1)
48+
METAFLOW_HUGGINGFACE_AUTH_PROVIDER = from_conf(
49+
"METAFLOW_HUGGINGFACE_AUTH_PROVIDER", "env"
50+
)
51+
4752
DEFAULT_FROM_DEPLOYMENT_IMPL = from_conf(
4853
"DEFAULT_FROM_DEPLOYMENT_IMPL", "argo-workflows"
4954
)
@@ -64,6 +69,7 @@
6469
"timeout",
6570
"conda_env_internal",
6671
"card",
72+
"huggingface",
6773
],
6874
)
6975

metaflow/metaflow_current.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,14 @@ def _raise(ex):
3232
self.__class__.graph = property(
3333
fget=lambda self: _raise(RuntimeError("Graph is not available"))
3434
)
35+
self.__class__.huggingface = property(
36+
fget=lambda self: _raise(
37+
RuntimeError(
38+
"current.huggingface is only available inside a step "
39+
"decorated with @huggingface"
40+
)
41+
)
42+
)
3543

3644
def _set_env(
3745
self,

metaflow/plugins/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
("airflow_internal", ".airflow.airflow_decorator.AirflowInternalDecorator"),
5656
("pypi", ".pypi.pypi_decorator.PyPIStepDecorator"),
5757
("conda", ".pypi.conda_decorator.CondaStepDecorator"),
58+
("huggingface", ".huggingface.huggingface_decorator.HuggingFaceDecorator"),
5859
]
5960

6061
# Add new flow decorators here
@@ -139,6 +140,11 @@
139140

140141
FLOW_DECORATORS_DESC += SENSOR_FLOW_DECORATORS
141142

143+
# HuggingFace decorator: auth providers (Part 1)
144+
HF_AUTH_PROVIDERS_DESC = [
145+
("env", ".huggingface.env_auth_provider.EnvHuggingFaceAuthProvider"),
146+
]
147+
142148
SECRETS_PROVIDERS_DESC = [
143149
("inline", ".secrets.inline_secrets_provider.InlineSecretsProvider"),
144150
(
@@ -213,6 +219,7 @@ def get_runner_cli_path():
213219

214220
AWS_CLIENT_PROVIDERS = resolve_plugins("aws_client_provider")
215221
SECRETS_PROVIDERS = resolve_plugins("secrets_provider")
222+
HF_AUTH_PROVIDERS = resolve_plugins("hf_auth_provider")
216223
AZURE_CLIENT_PROVIDERS = resolve_plugins("azure_client_provider")
217224
GCP_CLIENT_PROVIDERS = resolve_plugins("gcp_client_provider")
218225

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from .auth import HuggingFaceAuthProvider
2+
from .huggingface_decorator import HuggingFaceContext, HuggingFaceDecorator
3+
4+
__all__ = [
5+
"HuggingFaceAuthProvider",
6+
"HuggingFaceContext",
7+
"HuggingFaceDecorator",
8+
]
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
"""
2+
Pluggable HuggingFace authentication for the @huggingface decorator.
3+
4+
Implement HuggingFaceAuthProvider and register in HF_AUTH_PROVIDERS_DESC
5+
to supply tokens from org-specific backends. If no provider is configured,
6+
the decorator falls back to HF_TOKEN / HUGGING_FACE_HUB_TOKEN environment variables.
7+
"""
8+
9+
import abc
10+
from typing import Optional
11+
12+
13+
class HuggingFaceAuthProvider(abc.ABC):
14+
"""
15+
Interface for pluggable HuggingFace authentication.
16+
Providers are registered via HF_AUTH_PROVIDERS_DESC; the active provider
17+
is selected by METAFLOW_HUGGINGFACE_AUTH_PROVIDER (default: env).
18+
"""
19+
20+
# Unique identifier for this provider (e.g. "env", "netflix-internal")
21+
TYPE = None # type: Optional[str]
22+
23+
@abc.abstractmethod
24+
def get_token(self) -> Optional[str]:
25+
"""
26+
Return the HuggingFace API token to use for this task, or None if no auth.
27+
"""
28+
pass
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
"""
2+
Default HuggingFace auth provider: read token from environment variables.
3+
"""
4+
5+
import os
6+
from typing import Optional
7+
8+
from .auth import HuggingFaceAuthProvider
9+
10+
11+
class EnvHuggingFaceAuthProvider(HuggingFaceAuthProvider):
12+
"""
13+
Supplies HuggingFace token from HF_TOKEN or HUGGING_FACE_HUB_TOKEN.
14+
"""
15+
16+
TYPE = "env"
17+
18+
def get_token(self) -> Optional[str]:
19+
return os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
"""
2+
Minimal flow to test @huggingface locally.
3+
4+
Run from repo root:
5+
PYTHONPATH=. python -m metaflow.plugins.huggingface.example_flow run
6+
7+
Requires: pip install huggingface_hub
8+
Optional: set HF_TOKEN or HUGGING_FACE_HUB_TOKEN for gated models.
9+
"""
10+
11+
from metaflow import FlowSpec, step, huggingface
12+
13+
14+
class HuggingFaceExampleFlow(FlowSpec):
15+
@huggingface(models=["bert-base-uncased"])
16+
@step
17+
def start(self):
18+
from metaflow import current
19+
20+
self.model_path = current.huggingface.models["bert-base-uncased"]
21+
print("Model path:", self.model_path)
22+
self.next(self.end)
23+
24+
@step
25+
def end(self):
26+
print("Done. Model was at:", self.model_path)
27+
28+
29+
if __name__ == "__main__":
30+
HuggingFaceExampleFlow()
Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
"""
2+
@huggingface step decorator: pluggable auth for HuggingFace models (Part 1).
3+
4+
Provides current.huggingface.models[key] -> local path. Supports models=[] and
5+
model_mapping={alias: repo_id@revision}. Uses huggingface_hub for download.
6+
"""
7+
8+
import os
9+
from typing import Dict, List, Optional, Tuple
10+
11+
from metaflow.decorators import StepDecorator
12+
from metaflow.exception import MetaflowException
13+
from metaflow.metaflow_current import current
14+
15+
16+
# Minimal object exposed as current.huggingface with a .models mapping
17+
class HuggingFaceContext:
18+
"""
19+
Context object attached to current.huggingface when @huggingface is used.
20+
models maps user-facing key (alias or repo_id) to local filesystem path (str).
21+
"""
22+
23+
def __init__(self, models: Dict[str, str]):
24+
self.models = models
25+
26+
27+
def _parse_repo_spec(value: str) -> Tuple[str, str]:
28+
"""Parse 'repo_id' or 'repo_id@revision' into (repo_id, revision)."""
29+
value = (value or "").strip()
30+
if not value:
31+
raise MetaflowException(
32+
"@huggingface: empty model spec; use repo_id or repo_id@revision"
33+
)
34+
if "@" in value:
35+
repo_id, revision = value.rsplit("@", 1)
36+
repo_id = repo_id.strip()
37+
revision = revision.strip()
38+
if not repo_id or not revision:
39+
raise MetaflowException(
40+
"@huggingface: invalid spec '%s'; use repo_id@revision" % value
41+
)
42+
return repo_id, revision
43+
return value, "main"
44+
45+
46+
def _build_spec_map(
47+
models: Optional[List[str]], model_mapping: Optional[Dict[str, str]]
48+
) -> Dict[str, Tuple[str, str]]:
49+
"""Build key -> (repo_id, revision). Key is alias or repo_id."""
50+
spec_map = {}
51+
if models:
52+
for v in models:
53+
if not isinstance(v, str):
54+
raise MetaflowException(
55+
"@huggingface: models must be a list of strings, got %s" % type(v)
56+
)
57+
repo_id, revision = _parse_repo_spec(v)
58+
spec_map[repo_id] = (repo_id, revision)
59+
if model_mapping:
60+
for k, v in model_mapping.items():
61+
if not isinstance(k, str) or not isinstance(v, str):
62+
raise MetaflowException(
63+
"@huggingface: model_mapping must be dict of str -> str"
64+
)
65+
repo_id, revision = _parse_repo_spec(v)
66+
spec_map[k] = (repo_id, revision)
67+
return spec_map
68+
69+
70+
def _get_auth_provider():
71+
from metaflow.metaflow_config import METAFLOW_HUGGINGFACE_AUTH_PROVIDER
72+
from metaflow.plugins import HF_AUTH_PROVIDERS
73+
74+
provider_type = METAFLOW_HUGGINGFACE_AUTH_PROVIDER or "env"
75+
provider_cls = next(
76+
(p for p in HF_AUTH_PROVIDERS if getattr(p, "TYPE", None) == provider_type),
77+
None,
78+
)
79+
if provider_cls is None:
80+
from metaflow.plugins.huggingface.env_auth_provider import (
81+
EnvHuggingFaceAuthProvider,
82+
)
83+
84+
return EnvHuggingFaceAuthProvider()
85+
return provider_cls()
86+
87+
88+
def _download_model(
89+
repo_id: str, revision: str, token: Optional[str], local_dir: str
90+
) -> str:
91+
try:
92+
from huggingface_hub import snapshot_download
93+
except ImportError as e:
94+
raise MetaflowException(
95+
"@huggingface requires the 'huggingface_hub' package. "
96+
"Install it with: pip install huggingface_hub. Error: %s" % e
97+
) from e
98+
path = snapshot_download(
99+
repo_id=repo_id,
100+
revision=revision,
101+
token=token,
102+
local_dir=local_dir,
103+
local_dir_use_symlinks=False,
104+
)
105+
return path
106+
107+
108+
class HuggingFaceDecorator(StepDecorator):
109+
"""
110+
Declares HuggingFace models needed for this step. Auth is pluggable;
111+
model paths are exposed via current.huggingface.models[key].
112+
113+
Parameters
114+
----------
115+
models : list, optional
116+
List of repo ids (and optional revisions), e.g.
117+
["meta-llama/Llama-2-7b", "bert-base-uncased@v1.0"].
118+
model_mapping : dict, optional
119+
Alias -> repo spec, e.g.
120+
{"llama": "meta-llama/Llama-2-7b@main", "bert": "bert-base-uncased"}.
121+
Access in step via current.huggingface.models["llama"].
122+
123+
MF Add To Current
124+
-----------------
125+
huggingface -> HuggingFaceContext
126+
Object with a ``models`` attribute: dict-like mapping from model key
127+
(alias or repo_id) to local filesystem path (str). Use
128+
current.huggingface.models["key"] to get the path for loading with
129+
transformers or other HF APIs.
130+
"""
131+
132+
name = "huggingface"
133+
defaults = {"models": None, "model_mapping": None}
134+
135+
def step_init(
136+
self, flow, graph, step_name, decorators, environment, flow_datastore, logger
137+
):
138+
models = self.attributes.get("models")
139+
model_mapping = self.attributes.get("model_mapping")
140+
if not models and not model_mapping:
141+
raise MetaflowException(
142+
"@huggingface: specify at least one of 'models' or 'model_mapping'"
143+
)
144+
self._spec_map = _build_spec_map(models, model_mapping)
145+
if not self._spec_map:
146+
raise MetaflowException(
147+
"@huggingface: at least one model or model_mapping entry is required"
148+
)
149+
150+
def task_pre_step(
151+
self,
152+
step_name,
153+
task_datastore,
154+
metadata,
155+
run_id,
156+
task_id,
157+
flow,
158+
graph,
159+
retry_count,
160+
max_user_code_retries,
161+
ubf_context,
162+
inputs,
163+
):
164+
token = None
165+
try:
166+
auth_provider = _get_auth_provider()
167+
token = auth_provider.get_token()
168+
except Exception as e:
169+
raise MetaflowException(
170+
"@huggingface: auth provider failed: %s" % e
171+
) from e
172+
173+
base_dir = os.path.join(current.tempdir or "/tmp", "metaflow_huggingface")
174+
os.makedirs(base_dir, exist_ok=True)
175+
path_map = {} # key -> local path
176+
177+
for key, (repo_id, revision) in self._spec_map.items():
178+
task_subdir = os.path.join(
179+
base_dir, "%s_%s" % (repo_id.replace("/", "_"), revision)
180+
)
181+
local_path = _download_model(repo_id, revision, token, task_subdir)
182+
path_map[key] = local_path
183+
184+
ctx = HuggingFaceContext(models=path_map)
185+
current._update_env({"huggingface": ctx})

run_huggingface_demo.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
#!/usr/bin/env python
2+
"""
3+
Demo script for the @huggingface decorator.
4+
5+
Run from repo root:
6+
python run_huggingface_demo.py run
7+
8+
Or run a single step (no download):
9+
python run_huggingface_demo.py run --no-pylint 2>&1 | head -30
10+
11+
Requires: pip install -e . and pip install huggingface_hub
12+
Optional: HF_TOKEN or HUGGING_FACE_HUB_TOKEN for gated models.
13+
"""
14+
from metaflow import FlowSpec, step, huggingface, current
15+
16+
17+
class HuggingFaceDemoFlow(FlowSpec):
18+
"""Single step: download bert-base-uncased via @huggingface and print path."""
19+
20+
@huggingface(models=["bert-base-uncased"])
21+
@step
22+
def start(self):
23+
path = current.huggingface.models["bert-base-uncased"]
24+
print("Model path:", path)
25+
self.model_path = path
26+
self.next(self.end)
27+
28+
@step
29+
def end(self):
30+
print("Done. Model at:", self.model_path)
31+
32+
33+
if __name__ == "__main__":
34+
HuggingFaceDemoFlow()

0 commit comments

Comments
 (0)