Skip to content

Commit 19fae7c

Browse files
olegshaldybinOrbax Authors
authored andcommitted
Add scratch dir to ServingConfig
PiperOrigin-RevId: 878067893
1 parent 7fc5266 commit 19fae7c

File tree

3 files changed

+87
-49
lines changed

3 files changed

+87
-49
lines changed

export/orbax/export/data_processors/tf_data_processor_test.py

Lines changed: 63 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -169,52 +169,72 @@ def func(x, y, z):
169169
)
170170
)
171171

172-
processor.prepare(
173-
available_tensor_specs=(tf.TensorSpec(shape=(2, 3), dtype=tf.float32),
174-
tf.TensorSpec(shape=(2, 3), dtype=tf.float32),
175-
tf.constant(2, dtype=tf.int32)),
176-
bfloat16_options=converter_options,
177-
tf_trackable_resources=[v],
178-
)
179-
180-
self.assertEqual(
181-
processor.input_signature,
182-
(
183-
(
184-
obm.ShloTensorSpec(
185-
shape=(2, 3), dtype=obm.ShloDType.bf16, name='x'
186-
),
187-
obm.ShloTensorSpec(
188-
shape=(2, 3), dtype=obm.ShloDType.bf16, name='y'
189-
),
190-
obm.ShloTensorSpec(
191-
shape=(), dtype=obm.ShloDType.i32, name='z'
192-
),
172+
with self.subTest('without_scratch_dir'):
173+
with self.assertRaisesRegex(
174+
ValueError, 'conversion_base_dir must be provided'
175+
):
176+
processor.prepare(
177+
available_tensor_specs=(
178+
tf.TensorSpec(shape=(2, 3), dtype=tf.float32),
179+
tf.TensorSpec(shape=(2, 3), dtype=tf.float32),
180+
tf.constant(2, dtype=tf.int32),
193181
),
194-
{},
195-
),
196-
)
197-
self.assertEqual(
198-
processor.output_signature,
199-
obm.ShloTensorSpec(
200-
shape=(2, 3), dtype=obm.ShloDType.bf16, name='output_0'
201-
),
202-
)
182+
bfloat16_options=converter_options,
183+
tf_trackable_resources=[v],
184+
)
203185

204-
# Verify that the variables have been converted to bfloat16 too.
205-
model_dir = self.create_tempdir().full_path
206-
tf2obm.save_tf_functions(
207-
model_dir,
208-
{'preprocessor': processor.concrete_function},
209-
trackable_resources=[v],
210-
converter_options=converter_options,
211-
)
186+
with self.subTest('with_scratch_dir'):
187+
processor.prepare(
188+
available_tensor_specs=(
189+
tf.TensorSpec(shape=(2, 3), dtype=tf.float32),
190+
tf.TensorSpec(shape=(2, 3), dtype=tf.float32),
191+
tf.constant(2, dtype=tf.int32),
192+
),
193+
bfloat16_options=converter_options,
194+
scratch_dir=self.create_tempdir().full_path,
195+
tf_trackable_resources=[v],
196+
)
197+
198+
self.assertEqual(
199+
processor.input_signature,
200+
(
201+
(
202+
obm.ShloTensorSpec(
203+
shape=(2, 3), dtype=obm.ShloDType.bf16, name='x'
204+
),
205+
obm.ShloTensorSpec(
206+
shape=(2, 3), dtype=obm.ShloDType.bf16, name='y'
207+
),
208+
obm.ShloTensorSpec(
209+
shape=(), dtype=obm.ShloDType.i32, name='z'
210+
),
211+
),
212+
{},
213+
),
214+
)
215+
self.assertEqual(
216+
processor.output_signature,
217+
obm.ShloTensorSpec(
218+
shape=(2, 3), dtype=obm.ShloDType.bf16, name='output_0'
219+
),
220+
)
212221

213-
saved_model = tf.saved_model.load(os.path.join(model_dir, 'tf_saved_model'))
214-
restored_fn = saved_model.signatures['preprocessor']
222+
# Verify that the variables have been converted to bfloat16 too.
223+
model_dir = self.create_tempdir().full_path
224+
tf2obm.save_tf_functions(
225+
model_dir,
226+
{'preprocessor': processor.concrete_function},
227+
trackable_resources=[v],
228+
converter_options=converter_options,
229+
)
230+
231+
saved_model = tf.saved_model.load(
232+
os.path.join(model_dir, 'tf_saved_model')
233+
)
234+
restored_fn = saved_model.signatures['preprocessor']
215235

216-
self.assertLen(restored_fn.variables, 1)
217-
self.assertEqual(restored_fn.variables[0].dtype, tf.bfloat16)
236+
self.assertLen(restored_fn.variables, 1)
237+
self.assertEqual(restored_fn.variables[0].dtype, tf.bfloat16)
218238

219239
def test_bfloat16_convert_error(self):
220240
processor = tf_data_processor.TfDataProcessor(
@@ -232,6 +252,7 @@ def test_bfloat16_convert_error(self):
232252
scope=converter_options_v2_pb2.BFloat16OptimizationOptions.ALL,
233253
)
234254
),
255+
scratch_dir=self.create_tempdir().full_path,
235256
)
236257

237258
def test_prepare_with_shlo_bf16_inputs(self):

export/orbax/export/obm_configs.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,8 @@ class ObmExportOptions:
343343
together for efficiency (a technique called online batching). See
344344
`BatchOptions` above for configuration details. If set to None, batching
345345
is disabled, and requests are processed individually. #
346+
scratch_dir: Scratch directory to be used by OBM if needed.
346347
"""
347348

348349
batch_options: BatchOptions | None = None
350+
scratch_dir: str | None = None

model/orbax/experimental/model/tf2obm/_src/converter.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from collections.abc import Mapping, Sequence
1818
import copy
1919
import os
20-
import tempfile
2120
from typing import Any, Dict, NamedTuple, Tuple
2221

2322
from jax import tree_util as jax_tree_util
@@ -56,9 +55,11 @@ def _is_args_kwargs_pattern(tree: utils.TfSignature) -> bool:
5655
def convert_function(
5756
fn_name: str,
5857
fn: tf.types.experimental.ConcreteFunction,
58+
*,
5959
converter_options: (
6060
converter_options_v2_pb2.ConverterOptionsV2 | None
6161
) = None,
62+
conversion_base_dir: str | None = None,
6263
trackable_resources: Any | None = None,
6364
) -> obm.SerializableFunction:
6465
"""Converts the TF concrete function to an OBM function.
@@ -73,6 +74,8 @@ def convert_function(
7374
converter_options: The converter options to use for the TF SavedModel. If
7475
set, the TF SavedModel will be converted using Inference Converter V2 in
7576
order to get the correct types for the input and output signatures.
77+
conversion_base_dir: The base directory to save the converted TF SavedModel.
78+
The converted model will be deleted afterwards.
7679
trackable_resources: Trackable resources used by the function.
7780
7881
Returns:
@@ -83,8 +86,12 @@ def convert_function(
8386
output_names = _output_names(fn)
8487

8588
if converter_options is not None:
89+
if conversion_base_dir is None:
90+
raise ValueError(
91+
'conversion_base_dir must be provided when converter_options is set.'
92+
)
8693
converterted_signature_def = _get_converted_function_signature_def(
87-
fn_name, fn, trackable_resources, converter_options
94+
conversion_base_dir, fn_name, fn, trackable_resources, converter_options
8895
)
8996
input_signature = _copy_types_from_signature_def(
9097
fn.structured_input_signature,
@@ -510,6 +517,7 @@ def _copy_type(t: Any) -> Any:
510517

511518

512519
def _get_converted_function_signature_def(
520+
conversion_base_dir: str,
513521
fn_name: str,
514522
fn: tf.types.experimental.ConcreteFunction,
515523
trackable_resources: Any,
@@ -518,6 +526,7 @@ def _get_converted_function_signature_def(
518526
"""Saves the function, converts it, returns its SignatureDef.
519527
520528
Args:
529+
conversion_base_dir: The OBM model directory.
521530
fn_name: The name of the function in the SavedModel.
522531
fn: The concrete function to save.
523532
trackable_resources: The trackable resources to save.
@@ -534,16 +543,22 @@ def _get_converted_function_signature_def(
534543
False
535544
)
536545

537-
with tempfile.TemporaryDirectory() as temp_dir:
546+
tmp_model_dir = os.path.join(conversion_base_dir, f'converted_{fn_name}')
547+
tf.io.gfile.makedirs(tmp_model_dir)
548+
549+
try:
538550
save_tf_functions(
539-
temp_dir,
551+
tmp_model_dir,
540552
{fn_name: fn},
541553
trackable_resources=trackable_resources,
542554
converter_options=opts_copy,
543555
)
544556

545-
converted_model_path = os.path.join(temp_dir, OBM_TF_SAVED_MODEL_SUB_DIR)
557+
converted_model_path = os.path.join(
558+
tmp_model_dir, OBM_TF_SAVED_MODEL_SUB_DIR
559+
)
546560
with open(os.path.join(converted_model_path, 'saved_model.pb'), 'rb') as f:
547561
saved_model_proto = saved_model_pb2.SavedModel.FromString(f.read())
548-
549-
return saved_model_proto.meta_graphs[0].signature_def[fn_name]
562+
return saved_model_proto.meta_graphs[0].signature_def[fn_name]
563+
finally:
564+
tf.io.gfile.rmtree(tmp_model_dir)

0 commit comments

Comments
 (0)