Skip to content

Commit 6915718

Browse files
committed
Harden GCS dataset loading
1 parent aadf2f7 commit 6915718

4 files changed

Lines changed: 83 additions & 1 deletion

File tree

policyengine_uk/data/dataset_sources.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,13 @@ def _cached_dataset_path(
126126

127127
def _download_blob(blob, local_path: Path) -> None:
128128
local_path.parent.mkdir(parents=True, exist_ok=True)
129-
temporary_path = local_path.with_name(f"{local_path.name}.tmp")
129+
fd, temporary_path_name = tempfile.mkstemp(
130+
prefix=f".{local_path.name}.",
131+
suffix=".tmp",
132+
dir=local_path.parent,
133+
)
134+
os.close(fd)
135+
temporary_path = Path(temporary_path_name)
130136
try:
131137
blob.download_to_filename(str(temporary_path))
132138
os.replace(temporary_path, local_path)

policyengine_uk/simulation.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,11 @@ def build_from_dataset_source(
276276
self.build_from_url(dataset_source)
277277
return
278278
if dataset_source.startswith("gs://"):
279+
if dataset_source in _url_dataset_cache:
280+
multi_year_dataset = _url_dataset_cache[dataset_source]
281+
self.build_from_multi_year_dataset(multi_year_dataset)
282+
self.dataset = multi_year_dataset
283+
return
279284
dataset_file = materialize_gcs_dataset_url(dataset_source)
280285
self.build_from_file(dataset_file, cache_key=dataset_source)
281286
return

policyengine_uk/tests/test_dataset_sources.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import os
12
from pathlib import Path
23

34
import pytest
@@ -13,13 +14,15 @@ def __init__(self, name, generation, metadata=None, contents=b"dataset"):
1314
self.metadata = metadata
1415
self.contents = contents
1516
self.download_count = 0
17+
self.download_filenames = []
1618
self.reload_count = 0
1719

1820
def reload(self):
1921
self.reload_count += 1
2022

2123
def download_to_filename(self, filename):
2224
self.download_count += 1
25+
self.download_filenames.append(filename)
2326
Path(filename).write_bytes(self.contents)
2427

2528

@@ -187,3 +190,33 @@ def test_materialize_gcs_dataset_url_reuses_cached_file(monkeypatch, tmp_path):
187190
assert first_path == second_path
188191
assert Path(second_path).read_bytes() == b"current"
189192
assert current_blob.download_count == 1
193+
194+
195+
def test_download_blob_uses_unique_temp_path_for_each_download(monkeypatch, tmp_path):
196+
local_path = tmp_path / "cache" / "file.h5"
197+
created_temp_paths = []
198+
199+
def fake_mkstemp(*, prefix, suffix, dir):
200+
temporary_path = Path(dir) / f"{prefix}{len(created_temp_paths)}{suffix}"
201+
fd = os.open(temporary_path, os.O_CREAT | os.O_EXCL | os.O_RDWR, 0o600)
202+
created_temp_paths.append(temporary_path)
203+
return fd, str(temporary_path)
204+
205+
monkeypatch.setattr(dataset_sources.tempfile, "mkstemp", fake_mkstemp)
206+
blob = FakeBlob("data/file.h5", 444, contents=b"first")
207+
208+
dataset_sources._download_blob(blob, local_path)
209+
local_path.unlink()
210+
blob.contents = b"second"
211+
dataset_sources._download_blob(blob, local_path)
212+
213+
assert [
214+
Path(filename) for filename in blob.download_filenames
215+
] == created_temp_paths
216+
assert len(set(created_temp_paths)) == 2
217+
assert all(
218+
temporary_path.parent == local_path.parent
219+
for temporary_path in created_temp_paths
220+
)
221+
assert all(not temporary_path.exists() for temporary_path in created_temp_paths)
222+
assert local_path.read_bytes() == b"second"

policyengine_uk/tests/test_simulation_dataset_sources.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,44 @@ def fake_build_from_file(dataset_file, *, cache_key=None):
8585
}
8686

8787

88+
def test_dataset_source_reuses_cached_gcs_dataset_before_materializing(monkeypatch):
89+
captured = {}
90+
cached_dataset = object()
91+
simulation = Simulation.__new__(Simulation)
92+
url = "gs://policyengine-uk-data-private/enhanced_frs_2023_24.h5@1.55.10"
93+
94+
def fake_materialize_gcs_dataset_url(dataset_url):
95+
raise AssertionError("Cached gs:// datasets should not be materialized.")
96+
97+
def fake_build_from_file(dataset_file, *, cache_key=None):
98+
raise AssertionError("Cached gs:// datasets should not be read from disk.")
99+
100+
def fake_build_from_multi_year_dataset(dataset):
101+
captured["dataset"] = dataset
102+
103+
simulation_module._url_dataset_cache.pop(url, None)
104+
simulation_module._url_dataset_cache[url] = cached_dataset
105+
monkeypatch.setattr(
106+
simulation_module,
107+
"materialize_gcs_dataset_url",
108+
fake_materialize_gcs_dataset_url,
109+
)
110+
monkeypatch.setattr(simulation, "build_from_file", fake_build_from_file)
111+
monkeypatch.setattr(
112+
simulation,
113+
"build_from_multi_year_dataset",
114+
fake_build_from_multi_year_dataset,
115+
)
116+
117+
try:
118+
Simulation.build_from_dataset_source(simulation, url)
119+
120+
assert captured["dataset"] is cached_dataset
121+
assert simulation.dataset is cached_dataset
122+
finally:
123+
simulation_module._url_dataset_cache.pop(url, None)
124+
125+
88126
def test_dataset_source_rejects_unsupported_remote_urls():
89127
simulation = Simulation.__new__(Simulation)
90128

0 commit comments

Comments
 (0)