Skip to content

Commit 67b9025

Browse files
sagunbOrbax Authors
authored andcommitted
Internal only change
PiperOrigin-RevId: 866014308
1 parent e236a32 commit 67b9025

File tree

7 files changed

+168
-162
lines changed

7 files changed

+168
-162
lines changed

model/orbax/experimental/model/jd2obm/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,5 @@
1414

1515
from orbax.experimental.model.jd2obm.main_lib import *
1616
from orbax.experimental.model.jd2obm.utils import *
17-
from orbax.experimental.model.jd2obm import voxel_asset_map_pb2
18-
from orbax.experimental.model.jd2obm.voxel_mock import *
17+
from orbax.experimental.model.jd2obm import jd_asset_map_pb2
18+
from orbax.experimental.model.jd2obm.module import *

model/orbax/experimental/model/jd2obm/voxel_asset_map.proto renamed to model/orbax/experimental/model/jd2obm/jd_asset_map.proto

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
syntax = "proto3";
22

3-
package orbax_model_voxel_assets_map;
3+
package orbax_model_jd_assets_map;
44

55
option java_multiple_files = true;
66

77
// Defines a map from original asset paths to their new locations within the
8-
// saved model. This is used by the Voxel plan loader to correctly locate and
8+
// saved model. This is used by the Jax Data to correctly locate and
99
// load assets.
10-
message VoxelAssetsMap {
11-
// Key: The original path of an asset.
10+
message JDAssetsMap {
11+
// Key: The original path (or name) of an asset.
1212
// Value: The relative path of the asset within the Orbax saved model folder.
1313
map<string, string> assets = 1;
1414
}

model/orbax/experimental/model/jd2obm/main_lib.py

Lines changed: 74 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -16,46 +16,48 @@
1616
from collections.abc import Mapping
1717
import os
1818
import re
19-
from typing import Callable, Any
19+
from typing import Any, Callable
2020

2121
from orbax.experimental.model import core as obm
2222
from orbax.experimental.model.core.python import file_utils
23-
from orbax.experimental.model.jd2obm import voxel_asset_map_pb2
23+
from orbax.experimental.model.jd2obm import jd_asset_map_pb2
24+
from orbax.experimental.model.jd2obm import module
2425

2526

26-
VOXEL_PROCESSOR_MIME_TYPE = 'application/protobuf; type=voxel.PlanProto'
27-
VOXEL_PROCESSOR_VERSION = '0.0.1'
28-
DEFAULT_VOXEL_MODULE_FOLDER = 'voxel_module'
29-
VOXEL_ASSETS_FOLDER = 'assets'
30-
VOXEL_ASSET_MAP_MIME_TYPE = (
31-
'application/protobuf; type=orbax_model_voxel_assets_map.VoxelAssetsMap'
27+
# TODO(b/479875543): Update proto type to be voxel agnostic.
28+
JD_PROCESSOR_MIME_TYPE = 'application/protobuf; type=voxel.PlanProto'
29+
JD_PROCESSOR_VERSION = '0.0.1'
30+
DEFAULT_JD_MODULE_FOLDER = 'jd_module'
31+
JD_ASSETS_FOLDER = 'assets'
32+
JD_ASSET_MAP_MIME_TYPE = (
33+
'application/protobuf; type=orbax_model_jd_assets_map.JDAssetsMap'
3234
)
33-
VOXEL_ASSET_MAP_VERSION = '0.0.1'
34-
VOXEL_ASSET_MAP_SUPPLEMENTAL_NAME = 'voxel_asset_map'
35+
JD_ASSET_MAP_VERSION = '0.0.1'
36+
JD_ASSET_MAP_SUPPLEMENTAL_NAME = 'jd_asset_map'
3537

3638

37-
def voxel_plan_to_obm(
38-
voxel_module: Any,
39+
def jd_plan_to_obm(
40+
jd_module: module.JDModuleBase,
3941
input_signature: obm.Tree[obm.ShloTensorSpec],
4042
output_signature: obm.Tree[obm.ShloTensorSpec],
41-
subfolder: str = DEFAULT_VOXEL_MODULE_FOLDER,
43+
subfolder: str = DEFAULT_JD_MODULE_FOLDER,
4244
) -> obm.SerializableFunction:
43-
"""Converts a Voxel plan to an `obm.SerializableFunction`.
45+
"""Converts a JD plan to an `obm.SerializableFunction`.
4446
4547
Args:
46-
voxel_module: The Voxel module to be converted.
47-
input_signature: The input signature of the Voxel module.
48-
output_signature: The output signature of the Voxel module.
48+
jd_module: The JD module to be converted.
49+
input_signature: The input signature of the JD module.
50+
output_signature: The output signature of the JD module.
4951
subfolder: The name of the subfolder for the converted module.
5052
5153
Returns:
52-
An `obm.SerializableFunction` representing the Voxel module.
54+
An `obm.SerializableFunction` representing the JD module.
5355
"""
54-
plan = voxel_module.export_plan()
56+
plan = jd_module.export_plan()
5557
unstructured_data = obm.manifest_pb2.UnstructuredData(
5658
inlined_bytes=plan.SerializeToString(),
57-
mime_type=VOXEL_PROCESSOR_MIME_TYPE,
58-
version=VOXEL_PROCESSOR_VERSION,
59+
mime_type=JD_PROCESSOR_MIME_TYPE,
60+
version=JD_PROCESSOR_VERSION,
5961
)
6062

6163
obm_func = obm.SerializableFunction(
@@ -91,8 +93,8 @@ def _normalize_file_name(file_name: str) -> str:
9193
return f'{base}{ext}'
9294

9395

94-
class _VoxelAssetMapBuilder:
95-
"""Helper class to build VoxelAssetsMap efficiently."""
96+
class _JDAssetMapBuilder:
97+
"""Helper class to build JDAssetsMap efficiently."""
9698

9799
def __init__(self):
98100
# Maps unique/sanitized filenames to their original source paths.
@@ -101,11 +103,11 @@ def __init__(self):
101103
# Stores the next available index for a given base filename,
102104
# defaulting to 1 if the base filename hasn't been seen before.
103105
self._next_index_map: dict[str, int] = collections.defaultdict(lambda: 1)
104-
self._voxel_asset_map = voxel_asset_map_pb2.VoxelAssetsMap()
106+
self._jd_asset_map = jd_asset_map_pb2.JDAssetsMap()
105107

106108
@property
107-
def voxel_asset_map(self) -> voxel_asset_map_pb2.VoxelAssetsMap:
108-
return self._voxel_asset_map
109+
def jd_asset_map(self) -> jd_asset_map_pb2.JDAssetsMap:
110+
return self._jd_asset_map
109111

110112
def add_asset(self, source_path: str, subfolder: str) -> None:
111113
"""Adds an asset to the map, resolving name conflicts.
@@ -119,7 +121,7 @@ def add_asset(self, source_path: str, subfolder: str) -> None:
119121
subfolder: The name of the assets subfolder for the saved model.
120122
"""
121123
# The asset has been added before, skip.
122-
if source_path in self._voxel_asset_map.assets:
124+
if source_path in self._jd_asset_map.assets:
123125
return
124126

125127
file_name = os.path.basename(source_path)
@@ -136,108 +138,106 @@ def add_asset(self, source_path: str, subfolder: str) -> None:
136138

137139
# Add the unique file name to the maps.
138140
self._auxiliary_file_map[unique_file_name] = source_path
139-
self._voxel_asset_map.assets[source_path] = os.path.join(
140-
subfolder, VOXEL_ASSETS_FOLDER, unique_file_name
141+
self._jd_asset_map.assets[source_path] = os.path.join(
142+
subfolder, JD_ASSETS_FOLDER, unique_file_name
141143
)
142144

143145

144-
def _get_voxel_asset_map(
145-
asset_source_path: set[str], subfolder: str = DEFAULT_VOXEL_MODULE_FOLDER
146-
) -> voxel_asset_map_pb2.VoxelAssetsMap:
147-
"""Gets a VoxelAssetsMap proto for a given set of asset source paths.
146+
def _get_jd_asset_map(
147+
asset_source_path: set[str], subfolder: str = DEFAULT_JD_MODULE_FOLDER
148+
) -> jd_asset_map_pb2.JDAssetsMap:
149+
"""Gets a JDAssetsMap proto for a given set of asset source paths.
148150
149-
The VoxelAssetsMap proto contains a mapping from original asset paths to
151+
The JDAssetsMap proto contains a mapping from original asset paths to
150152
the new relative paths in the saved model directory.
151153
152154
Args:
153155
asset_source_path: A set of source paths of the assets.
154156
subfolder: The name of the subfolder for the converted module.
155157
156158
Returns:
157-
A VoxelAssetsMap proto.
159+
A JDAssetsMap proto.
158160
"""
159-
builder = _VoxelAssetMapBuilder()
161+
builder = _JDAssetMapBuilder()
160162
for source_path in sorted(list(asset_source_path)):
161163
builder.add_asset(source_path, subfolder)
162-
return builder.voxel_asset_map
164+
return builder.jd_asset_map
163165

164166

165-
def _save_assets(
166-
voxel_asset_map: voxel_asset_map_pb2.VoxelAssetsMap, path: str
167-
) -> None:
168-
"""Saves asset files based on the provided VoxelAssetsMap.
167+
def _save_assets(jd_asset_map: jd_asset_map_pb2.JDAssetsMap, path: str) -> None:
168+
"""Saves asset files based on the provided JDAssetsMap.
169169
170-
Iterates through the assets in voxel_asset_map and copies each asset from
170+
Iterates through the assets in jd_asset_map and copies each asset from
171171
its source path to destination. The destination path is constructed by joining
172172
`path` with the asset's relative path, and destination directories are
173173
created as needed.
174174
175175
Args:
176-
voxel_asset_map: A VoxelAssetsMap proto containing asset mappings.
176+
jd_asset_map: A JDAssetsMap proto containing asset mappings.
177177
path: The base destination directory to save the assets.
178178
"""
179-
for source_path, dest_relative_path in voxel_asset_map.assets.items():
179+
for source_path, dest_relative_path in jd_asset_map.assets.items():
180180
dest_path = os.path.join(path, dest_relative_path)
181181
file_utils.mkdir_p(os.path.dirname(dest_path))
182182
file_utils.copy(source_path, dest_path)
183183
return
184184

185185

186186
def _asset_map_to_obm_supplemental(
187-
voxel_asset_map: voxel_asset_map_pb2.VoxelAssetsMap,
187+
jd_asset_map: jd_asset_map_pb2.JDAssetsMap,
188188
) -> obm.GlobalSupplemental:
189-
"""Converts a VoxelAssetsMap proto to an obm.GlobalSupplemental object.
189+
"""Converts a JDAssetsMap proto to an obm.GlobalSupplemental object.
190190
191-
Serializes the VoxelAssetsMap to bytes and wraps it in an
191+
Serializes the JDAssetsMap to bytes and wraps it in an
192192
obm.UnstructuredData object, returning it as part of an
193193
obm.GlobalSupplemental object.
194194
195195
Args:
196-
voxel_asset_map: A VoxelAssetsMap proto to be converted.
196+
jd_asset_map: A JDAssetsMap proto to be converted.
197197
198198
Returns:
199-
An obm.GlobalSupplemental object containing the serialized voxel asset map.
199+
An obm.GlobalSupplemental object containing the serialized jd asset map.
200200
"""
201201
return obm.GlobalSupplemental(
202202
data=obm.UnstructuredData(
203-
inlined_bytes=voxel_asset_map.SerializeToString(),
204-
mime_type=VOXEL_ASSET_MAP_MIME_TYPE,
205-
version=VOXEL_ASSET_MAP_VERSION,
203+
inlined_bytes=jd_asset_map.SerializeToString(),
204+
mime_type=JD_ASSET_MAP_MIME_TYPE,
205+
version=JD_ASSET_MAP_VERSION,
206206
),
207-
save_as='voxel_asset_map.pb',
207+
save_as='jd_asset_map.pb',
208208
)
209209

210210

211-
def voxel_global_supplemental_closure(
212-
voxel_module: Any,
211+
def jd_global_supplemental_closure(
212+
jd_module: Any,
213213
) -> Callable[[str], Mapping[str, obm.GlobalSupplemental]] | None:
214-
"""Returns a closure for saving Voxel assets and creating supplemental data.
214+
"""Returns a closure for saving jd assets and creating supplemental data.
215215
216-
This function first generates a VoxelAssetsMap based on asset_source_paths.
216+
This function first generates a JDAssetsMap based on asset_source_paths.
217217
It then returns a closure function. When called, the closure saves the
218218
assets to a specified destination and returns an obm.GlobalSupplemental object
219219
containing the asset map.
220220
221221
Args:
222-
voxel_module: A Voxel module instance.
222+
jd_module: A Jax Data module instance.
223223
224224
Returns:
225225
A function that takes the asset destination path string, stores assets in it,
226-
and returns a dictionary of one entry, from the Voxel supplemental name to
227-
the obm.GlobalSupplemental object encoding the Voxel asset map.
226+
and returns a dictionary of one entry, from the JD supplemental name to
227+
the obm.GlobalSupplemental object encoding the JD asset map.
228228
"""
229-
asset_source_paths = voxel_module.export_assets()
229+
asset_source_paths = jd_module.export_assets()
230230
if not asset_source_paths:
231231
return None
232-
voxel_asset_map = _get_voxel_asset_map(asset_source_paths)
232+
jd_asset_map = _get_jd_asset_map(asset_source_paths)
233233

234234
def save_and_create_global_supplemental(
235235
path: str,
236236
) -> Mapping[str, obm.GlobalSupplemental]:
237-
_save_assets(voxel_asset_map, path)
237+
_save_assets(jd_asset_map, path)
238238
return {
239-
VOXEL_ASSET_MAP_SUPPLEMENTAL_NAME: _asset_map_to_obm_supplemental(
240-
voxel_asset_map
239+
JD_ASSET_MAP_SUPPLEMENTAL_NAME: _asset_map_to_obm_supplemental(
240+
jd_asset_map
241241
)
242242
}
243243

@@ -247,6 +247,13 @@ def save_and_create_global_supplemental(
247247
# Define `__all__` to explicitly declare the public API of this module.
248248
# This controls what `from jd2obm import *` imports and helps linters.
249249
__all__ = [
250-
'voxel_plan_to_obm',
251-
'voxel_global_supplemental_closure'
250+
'DEFAULT_JD_MODULE_FOLDER',
251+
'JD_PROCESSOR_MIME_TYPE',
252+
'JD_PROCESSOR_VERSION',
253+
'JD_ASSETS_FOLDER',
254+
'JD_ASSET_MAP_MIME_TYPE',
255+
'JD_ASSET_MAP_VERSION',
256+
'JD_ASSET_MAP_SUPPLEMENTAL_NAME',
257+
'jd_plan_to_obm',
258+
'jd_global_supplemental_closure',
252259
]

0 commit comments

Comments
 (0)