Skip to content

Commit 4313077

Browse files
Address reviews
1 parent b8407e3 commit 4313077

File tree

5 files changed

+43
-26
lines changed

5 files changed

+43
-26
lines changed

keras_remote/backend/execution.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from keras_remote.constants import get_default_zone, zone_to_region
2020
from keras_remote.credentials import ensure_credentials
2121
from keras_remote.infra import container_builder
22+
from keras_remote.infra.infra import get_default_project
2223
from keras_remote.utils import packager, storage
2324

2425

@@ -73,9 +74,7 @@ def from_params(
7374
if not zone:
7475
zone = get_default_zone()
7576
if not project:
76-
project = os.environ.get("KERAS_REMOTE_PROJECT") or os.environ.get(
77-
"GOOGLE_CLOUD_PROJECT"
78-
)
77+
project = get_default_project()
7978
if not project:
8079
raise ValueError(
8180
"project must be specified or set KERAS_REMOTE_PROJECT"

keras_remote/data.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,9 @@ def __repr__(self):
101101
return f"Data({self._raw_path!r})"
102102

103103

104-
def _make_data_ref(gcs_uri, is_dir, mount_path=None):
104+
def _make_data_ref(
105+
gcs_uri: str, is_dir: bool, mount_path: str | None = None
106+
) -> dict[str, object]:
105107
"""Create a serializable data reference dict.
106108
107109
These dicts replace Data objects in the payload before serialization.
@@ -115,6 +117,6 @@ def _make_data_ref(gcs_uri, is_dir, mount_path=None):
115117
}
116118

117119

118-
def is_data_ref(obj):
120+
def is_data_ref(obj: object) -> bool:
119121
"""Check if an object is a serialized data reference."""
120122
return isinstance(obj, dict) and obj.get("__data_ref__") is True

keras_remote/infra/infra.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,8 @@
55
logger = logging.getLogger("keras_remote")
66

77

8-
def get_default_project():
9-
return os.environ.get("KERAS_REMOTE_PROJECT")
8+
def get_default_project() -> str | None:
9+
"""Get project ID from KERAS_REMOTE_PROJECT or GOOGLE_CLOUD_PROJECT."""
10+
return os.environ.get("KERAS_REMOTE_PROJECT") or os.environ.get(
11+
"GOOGLE_CLOUD_PROJECT"
12+
)

keras_remote/utils/storage.py

Lines changed: 30 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,24 @@
11
"""Cloud Storage operations for keras_remote."""
22

3+
from __future__ import annotations
4+
35
import os
46
import tempfile
57

68
from absl import logging
79
from google.cloud import storage
810

9-
10-
def _get_project():
11-
"""Get project ID from environment or default gcloud config."""
12-
return os.environ.get("KERAS_REMOTE_PROJECT") or os.environ.get(
13-
"GOOGLE_CLOUD_PROJECT"
14-
)
11+
from keras_remote.data import Data
12+
from keras_remote.infra.infra import get_default_project
1513

1614

1715
def upload_artifacts(
18-
bucket_name, job_id, payload_path, context_path, project=None
19-
):
16+
bucket_name: str,
17+
job_id: str,
18+
payload_path: str,
19+
context_path: str,
20+
project: str | None = None,
21+
) -> None:
2022
"""Upload execution artifacts to Cloud Storage.
2123
2224
Args:
@@ -26,7 +28,7 @@ def upload_artifacts(
2628
context_path: Local path to context.zip
2729
project: GCP project ID (optional, uses env vars if not provided)
2830
"""
29-
project = project or _get_project()
31+
project = project or get_default_project()
3032

3133
client = storage.Client(project=project)
3234
bucket = client.bucket(bucket_name)
@@ -55,7 +57,9 @@ def upload_artifacts(
5557
)
5658

5759

58-
def download_result(bucket_name, job_id, project=None):
60+
def download_result(
61+
bucket_name: str, job_id: str, project: str | None = None
62+
) -> str:
5963
"""Download result from Cloud Storage.
6064
6165
Args:
@@ -66,7 +70,7 @@ def download_result(bucket_name, job_id, project=None):
6670
Returns:
6771
Local path to downloaded result file
6872
"""
69-
project = project or _get_project()
73+
project = project or get_default_project()
7074
client = storage.Client(project=project)
7175
bucket = client.bucket(bucket_name)
7276

@@ -80,15 +84,17 @@ def download_result(bucket_name, job_id, project=None):
8084
return local_path
8185

8286

83-
def cleanup_artifacts(bucket_name, job_id, project=None):
87+
def cleanup_artifacts(
88+
bucket_name: str, job_id: str, project: str | None = None
89+
) -> None:
8490
"""Clean up job artifacts from Cloud Storage.
8591
8692
Args:
8793
bucket_name: Name of the GCS bucket
8894
job_id: Unique job identifier
8995
project: GCP project ID (optional, uses env vars if not provided)
9096
"""
91-
project = project or _get_project()
97+
project = project or get_default_project()
9298
client = storage.Client(project=project)
9399
bucket = client.bucket(bucket_name)
94100

@@ -108,7 +114,12 @@ def cleanup_artifacts(bucket_name, job_id, project=None):
108114
)
109115

110116

111-
def upload_data(bucket_name, data, project=None, namespace_prefix="default"):
117+
def upload_data(
118+
bucket_name: str,
119+
data: Data,
120+
project: str | None = None,
121+
namespace_prefix: str = "default",
122+
) -> str:
112123
"""Upload a Data object to GCS with content-based caching.
113124
114125
For GCS Data: returns the original URI (no upload).
@@ -130,7 +141,7 @@ def upload_data(bucket_name, data, project=None, namespace_prefix="default"):
130141
content_hash = data.content_hash()
131142
cache_prefix = f"{namespace_prefix}/data-cache/{content_hash}"
132143

133-
project = project or _get_project()
144+
project = project or get_default_project()
134145
client = storage.Client(project=project)
135146
bucket = client.bucket(bucket_name)
136147

@@ -178,7 +189,7 @@ def upload_data(bucket_name, data, project=None, namespace_prefix="default"):
178189
return f"gs://{bucket_name}/{cache_prefix}"
179190

180191

181-
def _compute_total_size(path):
192+
def _compute_total_size(path: str) -> int:
182193
"""Compute total size in bytes of a file or directory."""
183194
if os.path.isfile(path):
184195
return os.path.getsize(path)
@@ -189,7 +200,9 @@ def _compute_total_size(path):
189200
return total
190201

191202

192-
def _upload_directory(bucket, local_dir, gcs_prefix):
203+
def _upload_directory(
204+
bucket: storage.Bucket, local_dir: str, gcs_prefix: str
205+
) -> None:
193206
"""Upload a local directory to GCS preserving structure."""
194207
for root, _dirs, files in os.walk(local_dir):
195208
for fname in files:

keras_remote/utils/storage_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@
99
from absl.testing import absltest, parameterized
1010

1111
from keras_remote.data import Data
12+
from keras_remote.infra.infra import get_default_project
1213
from keras_remote.utils.storage import (
1314
_compute_total_size,
14-
_get_project,
1515
_upload_directory,
1616
cleanup_artifacts,
1717
download_result,
@@ -140,7 +140,7 @@ def test_resolves_project(self, kr_project, gc_project, expected):
140140
if gc_project:
141141
env["GOOGLE_CLOUD_PROJECT"] = gc_project
142142
with mock.patch.dict(os.environ, env, clear=True):
143-
self.assertEqual(_get_project(), expected)
143+
self.assertEqual(get_default_project(), expected)
144144

145145

146146
class TestUploadData(_GcsTestBase):

0 commit comments

Comments (0)