1616from collections .abc import Mapping
1717import os
1818import re
19- from typing import Callable , Any
19+ from typing import Any , Callable
2020
2121from orbax .experimental .model import core as obm
2222from orbax .experimental .model .core .python import file_utils
23- from orbax .experimental .model .jd2obm import voxel_asset_map_pb2
23+ from orbax .experimental .model .jd2obm import jd_asset_map_pb2
24+ from orbax .experimental .model .jd2obm import module
2425
2526
26- VOXEL_PROCESSOR_MIME_TYPE = 'application/protobuf; type=voxel.PlanProto'
27- VOXEL_PROCESSOR_VERSION = '0.0.1'
28- DEFAULT_VOXEL_MODULE_FOLDER = 'voxel_module'
29- VOXEL_ASSETS_FOLDER = 'assets'
30- VOXEL_ASSET_MAP_MIME_TYPE = (
31- 'application/protobuf; type=orbax_model_voxel_assets_map.VoxelAssetsMap'
27+ # TODO(b/479875543): Update proto type to be voxel agnostic.
28+ JD_PROCESSOR_MIME_TYPE = 'application/protobuf; type=voxel.PlanProto'
29+ JD_PROCESSOR_VERSION = '0.0.1'
30+ DEFAULT_JD_MODULE_FOLDER = 'jd_module'
31+ JD_ASSETS_FOLDER = 'assets'
32+ JD_ASSET_MAP_MIME_TYPE = (
33+ 'application/protobuf; type=orbax_model_jd_assets_map.JDAssetsMap'
3234)
33- VOXEL_ASSET_MAP_VERSION = '0.0.1'
34- VOXEL_ASSET_MAP_SUPPLEMENTAL_NAME = 'voxel_asset_map '
35+ JD_ASSET_MAP_VERSION = '0.0.1'
36+ JD_ASSET_MAP_SUPPLEMENTAL_NAME = 'jd_asset_map '
3537
3638
37- def voxel_plan_to_obm (
38- voxel_module : Any ,
39+ def jd_plan_to_obm (
40+ jd_module : module . JDModuleBase ,
3941 input_signature : obm .Tree [obm .ShloTensorSpec ],
4042 output_signature : obm .Tree [obm .ShloTensorSpec ],
41- subfolder : str = DEFAULT_VOXEL_MODULE_FOLDER ,
43+ subfolder : str = DEFAULT_JD_MODULE_FOLDER ,
4244) -> obm .SerializableFunction :
43- """Converts a Voxel plan to an `obm.SerializableFunction`.
45+ """Converts a JD plan to an `obm.SerializableFunction`.
4446
4547 Args:
46- voxel_module : The Voxel module to be converted.
47- input_signature: The input signature of the Voxel module.
48- output_signature: The output signature of the Voxel module.
48+ jd_module : The JD module to be converted.
49+ input_signature: The input signature of the JD module.
50+ output_signature: The output signature of the JD module.
4951 subfolder: The name of the subfolder for the converted module.
5052
5153 Returns:
52- An `obm.SerializableFunction` representing the Voxel module.
54+ An `obm.SerializableFunction` representing the JD module.
5355 """
54- plan = voxel_module .export_plan ()
56+ plan = jd_module .export_plan ()
5557 unstructured_data = obm .manifest_pb2 .UnstructuredData (
5658 inlined_bytes = plan .SerializeToString (),
57- mime_type = VOXEL_PROCESSOR_MIME_TYPE ,
58- version = VOXEL_PROCESSOR_VERSION ,
59+ mime_type = JD_PROCESSOR_MIME_TYPE ,
60+ version = JD_PROCESSOR_VERSION ,
5961 )
6062
6163 obm_func = obm .SerializableFunction (
@@ -91,8 +93,8 @@ def _normalize_file_name(file_name: str) -> str:
9193 return f'{ base } { ext } '
9294
9395
94- class _VoxelAssetMapBuilder :
95- """Helper class to build VoxelAssetsMap efficiently."""
96+ class _JDAssetMapBuilder :
97+ """Helper class to build JDAssetsMap efficiently."""
9698
9799 def __init__ (self ):
98100 # Maps unique/sanitized filenames to their original source paths.
@@ -101,11 +103,11 @@ def __init__(self):
101103 # Stores the next available index for a given base filename,
102104 # defaulting to 1 if the base filename hasn't been seen before.
103105 self ._next_index_map : dict [str , int ] = collections .defaultdict (lambda : 1 )
104- self ._voxel_asset_map = voxel_asset_map_pb2 . VoxelAssetsMap ()
106+ self ._jd_asset_map = jd_asset_map_pb2 . JDAssetsMap ()
105107
106108 @property
107- def voxel_asset_map (self ) -> voxel_asset_map_pb2 . VoxelAssetsMap :
108- return self ._voxel_asset_map
109+ def jd_asset_map (self ) -> jd_asset_map_pb2 . JDAssetsMap :
110+ return self ._jd_asset_map
109111
110112 def add_asset (self , source_path : str , subfolder : str ) -> None :
111113 """Adds an asset to the map, resolving name conflicts.
@@ -119,7 +121,7 @@ def add_asset(self, source_path: str, subfolder: str) -> None:
119121 subfolder: The name of the assets subfolder for the saved model.
120122 """
121123 # The asset has been added before, skip.
122- if source_path in self ._voxel_asset_map .assets :
124+ if source_path in self ._jd_asset_map .assets :
123125 return
124126
125127 file_name = os .path .basename (source_path )
@@ -136,108 +138,106 @@ def add_asset(self, source_path: str, subfolder: str) -> None:
136138
137139 # Add the unique file name to the maps.
138140 self ._auxiliary_file_map [unique_file_name ] = source_path
139- self ._voxel_asset_map .assets [source_path ] = os .path .join (
140- subfolder , VOXEL_ASSETS_FOLDER , unique_file_name
141+ self ._jd_asset_map .assets [source_path ] = os .path .join (
142+ subfolder , JD_ASSETS_FOLDER , unique_file_name
141143 )
142144
143145
144- def _get_voxel_asset_map (
145- asset_source_path : set [str ], subfolder : str = DEFAULT_VOXEL_MODULE_FOLDER
146- ) -> voxel_asset_map_pb2 . VoxelAssetsMap :
147- """Gets a VoxelAssetsMap proto for a given set of asset source paths.
146+ def _get_jd_asset_map (
147+ asset_source_path : set [str ], subfolder : str = DEFAULT_JD_MODULE_FOLDER
148+ ) -> jd_asset_map_pb2 . JDAssetsMap :
149+ """Gets a JDAssetsMap proto for a given set of asset source paths.
148150
149- The VoxelAssetsMap proto contains a mapping from original asset paths to
151+ The JDAssetsMap proto contains a mapping from original asset paths to
150152 the new relative paths in the saved model directory.
151153
152154 Args:
153155 asset_source_path: A set of source paths of the assets.
154156 subfolder: The name of the subfolder for the converted module.
155157
156158 Returns:
157- A VoxelAssetsMap proto.
159+ A JDAssetsMap proto.
158160 """
159- builder = _VoxelAssetMapBuilder ()
161+ builder = _JDAssetMapBuilder ()
160162 for source_path in sorted (list (asset_source_path )):
161163 builder .add_asset (source_path , subfolder )
162- return builder .voxel_asset_map
164+ return builder .jd_asset_map
163165
164166
165- def _save_assets (
166- voxel_asset_map : voxel_asset_map_pb2 .VoxelAssetsMap , path : str
167- ) -> None :
168- """Saves asset files based on the provided VoxelAssetsMap.
167+ def _save_assets (jd_asset_map : jd_asset_map_pb2 .JDAssetsMap , path : str ) -> None :
168+ """Saves asset files based on the provided JDAssetsMap.
169169
170- Iterates through the assets in voxel_asset_map and copies each asset from
170+ Iterates through the assets in jd_asset_map and copies each asset from
171171 its source path to destination. The destination path is constructed by joining
172172 `path` with the asset's relative path, and destination directories are
173173 created as needed.
174174
175175 Args:
176- voxel_asset_map : A VoxelAssetsMap proto containing asset mappings.
176+ jd_asset_map : A JDAssetsMap proto containing asset mappings.
177177 path: The base destination directory to save the assets.
178178 """
179- for source_path , dest_relative_path in voxel_asset_map .assets .items ():
179+ for source_path , dest_relative_path in jd_asset_map .assets .items ():
180180 dest_path = os .path .join (path , dest_relative_path )
181181 file_utils .mkdir_p (os .path .dirname (dest_path ))
182182 file_utils .copy (source_path , dest_path )
183183 return
184184
185185
186186def _asset_map_to_obm_supplemental (
187- voxel_asset_map : voxel_asset_map_pb2 . VoxelAssetsMap ,
187+ jd_asset_map : jd_asset_map_pb2 . JDAssetsMap ,
188188) -> obm .GlobalSupplemental :
189- """Converts a VoxelAssetsMap proto to an obm.GlobalSupplemental object.
189+ """Converts a JDAssetsMap proto to an obm.GlobalSupplemental object.
190190
191- Serializes the VoxelAssetsMap to bytes and wraps it in an
191+ Serializes the JDAssetsMap to bytes and wraps it in an
192192 obm.UnstructuredData object, returning it as part of an
193193 obm.GlobalSupplemental object.
194194
195195 Args:
196- voxel_asset_map : A VoxelAssetsMap proto to be converted.
196+ jd_asset_map : A JDAssetsMap proto to be converted.
197197
198198 Returns:
199- An obm.GlobalSupplemental object containing the serialized voxel asset map.
199+ An obm.GlobalSupplemental object containing the serialized jd asset map.
200200 """
201201 return obm .GlobalSupplemental (
202202 data = obm .UnstructuredData (
203- inlined_bytes = voxel_asset_map .SerializeToString (),
204- mime_type = VOXEL_ASSET_MAP_MIME_TYPE ,
205- version = VOXEL_ASSET_MAP_VERSION ,
203+ inlined_bytes = jd_asset_map .SerializeToString (),
204+ mime_type = JD_ASSET_MAP_MIME_TYPE ,
205+ version = JD_ASSET_MAP_VERSION ,
206206 ),
207- save_as = 'voxel_asset_map .pb' ,
207+ save_as = 'jd_asset_map .pb' ,
208208 )
209209
210210
211- def voxel_global_supplemental_closure (
212- voxel_module : Any ,
211+ def jd_global_supplemental_closure (
212+ jd_module : Any ,
213213) -> Callable [[str ], Mapping [str , obm .GlobalSupplemental ]] | None :
214- """Returns a closure for saving Voxel assets and creating supplemental data.
214+ """Returns a closure for saving jd assets and creating supplemental data.
215215
216- This function first generates a VoxelAssetsMap based on asset_source_paths.
216+ This function first generates a JDAssetsMap based on asset_source_paths.
217217 It then returns a closure function. When called, the closure saves the
218218 assets to a specified destination and returns an obm.GlobalSupplemental object
219219 containing the asset map.
220220
221221 Args:
222- voxel_module : A Voxel module instance.
222+ jd_module : A Jax Data module instance.
223223
224224 Returns:
225225 A function that takes the asset destination path string, stores assets in it,
226- and returns a dictionary of one entry, from the Voxel supplemental name to
227- the obm.GlobalSupplemental object encoding the Voxel asset map.
226+ and returns a dictionary of one entry, from the JD supplemental name to
227+ the obm.GlobalSupplemental object encoding the JD asset map.
228228 """
229- asset_source_paths = voxel_module .export_assets ()
229+ asset_source_paths = jd_module .export_assets ()
230230 if not asset_source_paths :
231231 return None
232- voxel_asset_map = _get_voxel_asset_map (asset_source_paths )
232+ jd_asset_map = _get_jd_asset_map (asset_source_paths )
233233
234234 def save_and_create_global_supplemental (
235235 path : str ,
236236 ) -> Mapping [str , obm .GlobalSupplemental ]:
237- _save_assets (voxel_asset_map , path )
237+ _save_assets (jd_asset_map , path )
238238 return {
239- VOXEL_ASSET_MAP_SUPPLEMENTAL_NAME : _asset_map_to_obm_supplemental (
240- voxel_asset_map
239+ JD_ASSET_MAP_SUPPLEMENTAL_NAME : _asset_map_to_obm_supplemental (
240+ jd_asset_map
241241 )
242242 }
243243
@@ -247,6 +247,13 @@ def save_and_create_global_supplemental(
247247# Define `__all__` to explicitly declare the public API of this module.
248248# This controls what `from jd2obm import *` imports and helps linters.
249249__all__ = [
250- 'voxel_plan_to_obm' ,
251- 'voxel_global_supplemental_closure'
250+ 'DEFAULT_JD_MODULE_FOLDER' ,
251+ 'JD_PROCESSOR_MIME_TYPE' ,
252+ 'JD_PROCESSOR_VERSION' ,
253+ 'JD_ASSETS_FOLDER' ,
254+ 'JD_ASSET_MAP_MIME_TYPE' ,
255+ 'JD_ASSET_MAP_VERSION' ,
256+ 'JD_ASSET_MAP_SUPPLEMENTAL_NAME' ,
257+ 'jd_plan_to_obm' ,
258+ 'jd_global_supplemental_closure' ,
252259]
0 commit comments