Skip to content

Commit ba4b1a6

Browse files
authored
Merge pull request #1777 from PolicyEngine/codex/uk-gcs-dataset-loading
Support direct GCS dataset sources
2 parents 8f31a59 + 6915718 commit ba4b1a6

6 files changed

Lines changed: 457 additions & 5 deletions

File tree

changelog.d/1776.added.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
- Added direct `gs://` dataset loading for UK simulations, including support for GCS generations and PolicyEngine data-version metadata.

docs/book/usage/simulations.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,19 @@ sim = Simulation(dataset=dataset)
278278
print(sim.calculate("household_net_income", 2026))
279279
```
280280

281+
`Simulation` and `Microsimulation` can also load H5 files from local paths,
282+
Hugging Face URLs, or Google Cloud Storage URLs:
283+
284+
```python
285+
sim = Microsimulation(
286+
dataset="gs://policyengine-uk-data-private/enhanced_frs_2023_24.h5@1.55.10"
287+
)
288+
```
289+
290+
For `gs://` URLs, a numeric suffix after `@` pins an exact GCS generation. A
291+
non-numeric suffix pins the PolicyEngine data version stored in the object's
292+
GCS metadata.
293+
281294
### From survey datasets
282295

283296
For population-level analysis, use survey data:
Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
import hashlib
2+
import os
3+
import tempfile
4+
from pathlib import Path
5+
from typing import Optional, Union
6+
7+
from policyengine_core.tools.google_cloud import parse_gs_url
8+
9+
10+
def materialize_gcs_dataset_url(
11+
dataset_url: str,
12+
*,
13+
cache_dir: Optional[Union[str, os.PathLike]] = None,
14+
) -> str:
15+
"""Download a GCS dataset URL to a local H5 path and return that path."""
16+
bucket_name, file_path, revision = parse_gs_url(dataset_url)
17+
storage_client = _get_storage_client()
18+
blob = _resolve_gcs_blob(storage_client, bucket_name, file_path, revision)
19+
generation = _blob_generation(blob)
20+
21+
local_path = _cached_dataset_path(
22+
bucket_name=bucket_name,
23+
file_path=file_path,
24+
generation=generation,
25+
cache_dir=cache_dir,
26+
)
27+
if not local_path.exists():
28+
_download_blob(blob, local_path)
29+
return str(local_path)
30+
31+
32+
def _get_storage_client():
33+
try:
34+
import google.auth
35+
from google.auth import exceptions as auth_exceptions
36+
from google.cloud import storage
37+
except ImportError as exc:
38+
raise ImportError(
39+
"google-cloud-storage is required for gs:// dataset URLs. "
40+
"Install it with: pip install google-cloud-storage"
41+
) from exc
42+
43+
try:
44+
credentials, project_id = google.auth.default()
45+
except auth_exceptions.DefaultCredentialsError as exc:
46+
raise RuntimeError(
47+
"Google Cloud credentials are required for gs:// dataset URLs. "
48+
"Set application default credentials or GOOGLE_APPLICATION_CREDENTIALS."
49+
) from exc
50+
51+
return storage.Client(credentials=credentials, project=project_id)
52+
53+
54+
def _resolve_gcs_blob(
55+
storage_client,
56+
bucket_name: str,
57+
file_path: str,
58+
revision: Optional[str],
59+
):
60+
bucket = storage_client.bucket(bucket_name)
61+
62+
if revision is not None and revision.isdigit():
63+
blob = bucket.blob(file_path, generation=int(revision))
64+
blob.reload()
65+
return blob
66+
67+
current_blob = bucket.blob(file_path)
68+
current_blob.reload()
69+
if revision is None or _blob_metadata_version(current_blob) == revision:
70+
return current_blob
71+
72+
matching_blobs = []
73+
for blob in storage_client.list_blobs(
74+
bucket_name,
75+
prefix=file_path,
76+
versions=True,
77+
):
78+
if blob.name != file_path:
79+
continue
80+
if _blob_metadata_version(blob) == revision:
81+
matching_blobs.append(blob)
82+
83+
if not matching_blobs:
84+
raise ValueError(
85+
f"No GCS object version for gs://{bucket_name}/{file_path} has "
86+
f"metadata version {revision!r}."
87+
)
88+
89+
return max(matching_blobs, key=lambda blob: int(_blob_generation(blob)))
90+
91+
92+
def _blob_metadata_version(blob) -> Optional[str]:
93+
if getattr(blob, "metadata", None) is None:
94+
blob.reload()
95+
metadata = getattr(blob, "metadata", None) or {}
96+
return metadata.get("version")
97+
98+
99+
def _blob_generation(blob) -> str:
100+
generation = getattr(blob, "generation", None)
101+
if generation is None:
102+
blob.reload()
103+
generation = getattr(blob, "generation", None)
104+
if generation is None:
105+
raise ValueError(f"GCS object {blob.name!r} does not expose a generation.")
106+
return str(generation)
107+
108+
109+
def _cached_dataset_path(
110+
*,
111+
bucket_name: str,
112+
file_path: str,
113+
generation: str,
114+
cache_dir: Optional[Union[str, os.PathLike]],
115+
) -> Path:
116+
if cache_dir is None:
117+
cache_dir = Path(tempfile.gettempdir()) / "policyengine-uk-datasets"
118+
else:
119+
cache_dir = Path(cache_dir)
120+
121+
cache_key = hashlib.sha256(
122+
f"{bucket_name}\0{file_path}\0{generation}".encode()
123+
).hexdigest()
124+
return cache_dir / cache_key / Path(file_path).name
125+
126+
127+
def _download_blob(blob, local_path: Path) -> None:
128+
local_path.parent.mkdir(parents=True, exist_ok=True)
129+
fd, temporary_path_name = tempfile.mkstemp(
130+
prefix=f".{local_path.name}.",
131+
suffix=".tmp",
132+
dir=local_path.parent,
133+
)
134+
os.close(fd)
135+
temporary_path = Path(temporary_path_name)
136+
try:
137+
blob.download_to_filename(str(temporary_path))
138+
os.replace(temporary_path, local_path)
139+
finally:
140+
temporary_path.unlink(missing_ok=True)

policyengine_uk/simulation.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
extend_single_year_dataset,
2727
reset_growthfactor_uprating,
2828
)
29+
from policyengine_uk.data.dataset_sources import materialize_gcs_dataset_url
2930
from policyengine_uk.utils.dependencies import get_variable_dependencies
3031
from policyengine_uk.reforms import create_structural_reforms_from_parameters
3132
from policyengine_uk.parameters.gov.simulation.labour_supply_responses.aliases import (
@@ -274,11 +275,19 @@ def build_from_dataset_source(
274275
if dataset_source.startswith("hf://"):
275276
self.build_from_url(dataset_source)
276277
return
278+
if dataset_source.startswith("gs://"):
279+
if dataset_source in _url_dataset_cache:
280+
multi_year_dataset = _url_dataset_cache[dataset_source]
281+
self.build_from_multi_year_dataset(multi_year_dataset)
282+
self.dataset = multi_year_dataset
283+
return
284+
dataset_file = materialize_gcs_dataset_url(dataset_source)
285+
self.build_from_file(dataset_file, cache_key=dataset_source)
286+
return
277287
if "://" in dataset_source:
278288
raise ValueError(
279-
"Only HuggingFace dataset URLs are supported directly by "
280-
"policyengine-uk. Download or materialize other dataset "
281-
"sources to a local file path before passing them to Simulation."
289+
"Only HuggingFace, Google Cloud Storage, and local dataset "
290+
"sources are supported by policyengine-uk."
282291
)
283292
self.build_from_file(dataset_source)
284293

0 commit comments

Comments
 (0)