Skip to content

Commit cacfd9f

Browse files
author
The paxml Authors
committed
Add + improve Pax status logging.
PiperOrigin-RevId: 534090287
1 parent 75c674b commit cacfd9f

File tree

5 files changed

+48
-23
lines changed

5 files changed

+48
-23
lines changed

paxml/checkpoint_creators.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -561,6 +561,7 @@ def _create_checkpointer(
561561
tensorstore_use_ocdbt: bool = False,
562562
) -> checkpoints.TrainingCheckpointer:
563563
"""Creates a checkpoint manager."""
564+
logging.info('[PAX STATUS]: Creating checkpointer.')
564565
checkpoint_dir = _make_checkpoint_dir(job_log_dir)
565566
train_p = task_p.train
566567
max_to_keep = train_p.save_max_to_keep

paxml/executors.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@
3838
from praxis import base_layer
3939
from praxis import pax_fiddle
4040
from praxis import py_utils
41-
from praxis import pytypes
4241
import tensorflow.compat.v2 as tf
4342

4443
from paxml import checkpoints # mapped to internal
@@ -161,6 +160,7 @@ def _maybe_create_train_input(
161160
passed to checkpointer.get_model_states(). If set, the checkpointer will
162161
restore its states from checkpoint.
163162
"""
163+
logging.info('[PAX STATUS]: Instantiating train input pipeline.')
164164
if not task_p.train.enable_input_checkpointing:
165165
_maybe_update_latest_model_step(train_input_p, step, task_p)
166166
train_input = instantiate(train_input_p)
@@ -219,6 +219,7 @@ def setup(
219219
'No training input specs available, while enabling '
220220
'`task_p.train.enforce_input_specs` requires it.'
221221
)
222+
logging.info('[PAX STATUS]: Setting up partitioner')
222223
partitioner.setup(
223224
jax_task,
224225
root_prng_key,
@@ -291,6 +292,7 @@ def _create_decode_programs(self, decode_input_ps):
291292
return decode_programs
292293

293294
def start(self):
295+
logging.info('Starting executor.')
294296
is_vars_replicated = self._task.model.ici_mesh_shape is None
295297
_train_and_evaluate_common(
296298
task=self._task,
@@ -313,9 +315,11 @@ def start(self):
313315
)
314316

315317
# Shutdown the programs and run necessary cleanup.
318+
logging.info('[PAX STATUS]: Shutting down executor.')
316319
self._train_program.shutdown()
317320
for program in self._eval_programs:
318321
program.shutdown()
322+
logging.info('[PAX STATUS]: Executor shutdown complete.')
319323

320324

321325
def _get_partition_decode_once_fn(
@@ -454,7 +458,7 @@ def _train_and_evaluate_common(
454458
f' number {initial_global_step} mismatch.'
455459
)
456460

457-
logging.info('Training loop starting...')
461+
logging.info('[PAX STATUS]: Starting training loop.')
458462
with _DecodeSummaryWriters(
459463
job_log_dir, decode_input_names
460464
) as decode_summary_writers:
@@ -497,7 +501,7 @@ def _train_and_evaluate_common(
497501
gc.collect()
498502
gc.freeze()
499503
while True:
500-
logging.debug('step=`%d`: Beginning', step_i)
504+
logging.debug('[PAX STATUS]: Beginning step `%d`.', step_i)
501505
checkpointer.save_if_needed(
502506
step_i,
503507
partitioned_train_state,
@@ -538,13 +542,15 @@ def _train_and_evaluate_common(
538542
train_p.eval_interval_steps
539543
and step_i % train_p.eval_interval_steps == 0
540544
):
541-
logging.debug(' Starting eval_step().')
545+
logging.debug('[PAX STATUS]: Starting eval_step().')
542546
eval_partitioned_train_state = programs.get_eval_train_state(
543547
task, partitioned_train_state
544548
)
545549
# If we have eval test then also evaluate on test.
546550
if eval_programs:
547-
logging.debug(' Performing eval_step() runs on test splits.')
551+
logging.debug(
552+
'[PAX STATUS]: Performing eval_step() runs on test splits.'
553+
)
548554
with py_utils.timeit() as eval_period:
549555
eval_metrics_list, eval_scoring_metrics_list, num_eval_steps = (
550556
eval_lib.run_eval_loop_over_test_splits(
@@ -566,7 +572,8 @@ def _train_and_evaluate_common(
566572
input_names=[prog.eval_input.name for prog in eval_programs],
567573
)
568574
logging.debug(
569-
' Completed eval_step() runs on test splits in %f seconds.',
575+
'[PAX STATUS]: Completed eval_step() runs on test splits in %f'
576+
' seconds.',
570577
eval_period.elapsed,
571578
)
572579

@@ -586,7 +593,9 @@ def _train_and_evaluate_common(
586593
decode_partitioned_train_state = tasks_lib.extract_ema(
587594
partitioned_train_state
588595
)
589-
logging.debug(' Performing decode_once_fn() with ema states.')
596+
logging.debug(
597+
'[PAX STATUS]: Performing decode_once_fn() with EMA states.'
598+
)
590599
else:
591600
decode_partitioned_train_state = partitioned_train_state
592601
decode_metrics = decode_once_fn(
@@ -595,8 +604,7 @@ def _train_and_evaluate_common(
595604
jax.monitoring.record_event_duration_secs(
596605
'/jax/pax/train/interleaved_decode_duration_sec',
597606
decode_period.elapsed)
598-
599-
logging.debug('step=`%d`: End', step_i - 1)
607+
logging.debug('[PAX STATUS]: Step `%d` completed.', step_i - 1)
600608

601609
if early_stopping_fn is not None:
602610
if tuning_lib.should_early_stop(
@@ -633,7 +641,8 @@ def _train_and_evaluate_common(
633641
)
634642
break
635643
gc.unfreeze()
636-
# Save checkpoint for the last step.
644+
645+
logging.info('[PAX STATUS]: Saving checkpoint for final step.')
637646
checkpointer.save_final(
638647
step_i,
639648
partitioned_train_state=partitioned_train_state,
@@ -643,3 +652,4 @@ def _train_and_evaluate_common(
643652
)
644653

645654
checkpointer.wait_until_finished()
655+
logging.info('[PAX STATUS]: Final checkpoint saved.')

paxml/partitioning.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,8 +347,10 @@ def setup(
347347
)
348348

349349
if train_inputs_shape_dtype:
350+
logging.info('[PAX STATUS]: Getting input shapes from spec.')
350351
self._train_inputs_shape_dtype = train_inputs_shape_dtype
351352
else:
353+
logging.info('[PAX STATUS]: Getting input shapes from first batch.')
352354
self._train_inputs_shape_dtype = self._get_train_inputs_shape_dtype(
353355
train_input_pipeline
354356
)

paxml/programs.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def get_eval_train_state(task: tasks_lib.SingleTask, state: TrainState):
7676
'learner does not seem to have ema enabled'
7777
)
7878
eval_state = tasks_lib.extract_ema(state).to_eval_state()
79-
logging.debug(' Converted train state to eval with ema state.')
79+
logging.info('[PAX STATUS]: Converted train state to eval with EMA state.')
8080
else:
8181
eval_state = state.to_eval_state()
8282
return eval_state
@@ -207,6 +207,7 @@ def setup(
207207
eval_prng_seed: PRNGKey,
208208
init_step: int,
209209
) -> None:
210+
logging.info('[PAX STATUS]: Setting up BaseTrainProgram.')
210211
self._task = task
211212
self._train_input = train_input
212213
self._partitioner = partitioner
@@ -263,7 +264,7 @@ def should_run(self, state: TrainState, step: int) -> bool:
263264
# correspondingly.
264265
def run(self, state: TrainState, step: int) -> ProgramOutput:
265266
train_p = self._task.train
266-
logging.debug(' Retrieving inputs.')
267+
logging.debug('[PAX STATUS]: Retrieving inputs.')
267268

268269
model_inputs = self._train_input.get_next_padded()
269270

@@ -276,7 +277,7 @@ def run(self, state: TrainState, step: int) -> ProgramOutput:
276277
model_inputs, ## First two args can be consolidated
277278
self.train_input_partition_spec(model_inputs),
278279
)
279-
logging.debug(' Retrieved inputs.')
280+
logging.debug('[PAX STATUS]: Retrieved inputs.')
280281

281282
# Waits if it reaches max inflight steps. We do this after retrieving the
282283
# inputs to maximize efficiency.
@@ -287,7 +288,7 @@ def run(self, state: TrainState, step: int) -> ProgramOutput:
287288
if do_profile and step - self._initial_step == profiler_capture_step:
288289
self._profiler.capture_async()
289290

290-
logging.debug(' Performing train_step().')
291+
logging.debug('[PAX STATUS]: Performing train_step().')
291292
with jax.profiler.StepTraceAnnotation('train', step_num=step):
292293
with py_utils.timeit() as train_period:
293294
new_step, new_state, train_outputs = self.train_step(
@@ -297,21 +298,21 @@ def run(self, state: TrainState, step: int) -> ProgramOutput:
297298
model_inputs,
298299
self._train_unpadded_global_batch_size,
299300
)
300-
del state # Unused anymore.
301+
del state # Unused.
301302
jax.monitoring.record_event_duration_secs(
302303
'/jax/pax/train/duration_sec', train_period.elapsed
303304
)
304305
logging.debug(
305-
' Completed train_step() in %f seconds.', train_period.elapsed
306+
'[PAX STATUS]: train_step() took %f seconds.', train_period.elapsed
306307
)
307308
self._pending_train_losses.add_computation(train_outputs.loss)
308309
if step == self._initial_step:
309310
self._first_step_completion_time = time.time()
310311

311312
if do_profile and step - self._initial_step < profiler_capture_step:
312313
self._profiler.update_step_moving_mean(train_period.elapsed)
314+
logging.debug('[PAX STATUS]: Writing summaries (attempt).')
313315
steps_per_sec = self._maybe_write_summaries(step, new_step, train_outputs)
314-
logging.debug(' Writing summaries (attempt).')
315316

316317
# Run eval at regular step interval.
317318
# While the eval ones below are post-model weight updates, hence we use the
@@ -402,7 +403,7 @@ def _maybe_write_summaries(
402403
per_example_out=train_outputs.per_example_out,
403404
steps_per_sec=steps_per_sec,
404405
)
405-
logging.debug(' Wrote summaries (attempted).')
406+
logging.debug('[PAX STATUS]: Wrote summaries (attempted).')
406407
return steps_per_sec
407408

408409
def _compute_steps_per_sec(self, step: int):
@@ -474,7 +475,7 @@ def _maybe_run_eval_train(self, new_state: TrainState, new_step: int):
474475
if self._eval_train_summary_handler.process(
475476
new_step, loss, weighted_scalars, summary_tensors
476477
):
477-
logging.debug(' Wrote eval summaries.')
478+
logging.debug('[PAX STATUS]: Wrote eval summaries.')
478479
eval_train_metrics = metric_utils.as_float_dict(weighted_scalars)
479480
return eval_train_metrics
480481

@@ -680,7 +681,9 @@ def setup(
680681

681682
# Creates the eval input pipeline.
682683
self._input_p = self._partitioner.preprocess_input_config(self._input_p)
683-
logging.debug('Initializing eval_input pipeline : %s', self._input_p)
684+
logging.info(
685+
'[PAX STATUS]: Initializing eval_input pipeline : %s', self._input_p
686+
)
684687
self._eval_input_pipeline = instantiate(self._input_p)
685688
self._name = self.eval_input.name
686689
self._eval_unpadded_global_batch_size = (

paxml/train.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -136,16 +136,19 @@ def train_and_evaluate(
136136
on-demand checkpoint due to preemption.
137137
"""
138138
jax.monitoring.record_event('/jax/pax/train_and_evaluate/beacon')
139+
logging.info('[PAX STATUS] Starting `train_and_evaluate`')
139140
task_p = experiment_config.task()
140141
task_p = typing.cast(pax_fiddle.Config[tasks_lib.SingleTask], task_p)
141142

142143
# in case the user passed in a string dtype, convert it to an actual dtype
143144
task_p.model.fprop_dtype = jnp.dtype(task_p.model.fprop_dtype)
144145

146+
logging.info('[PAX STATUS] Obtaining and initializing datasets.')
145147
input_p = experiment_config.datasets()
146148
for inp in input_p:
147149
if not isinstance(
148-
inp, (base_input.BaseInput.HParams, base_input.DistributedInputHParams)
150+
inp,
151+
(base_input.BaseInput.HParams, base_input.DistributedInputHParams),
149152
):
150153
raise ValueError(
151154
f'Expecting BaseInput.HParams from datasets(), got: {inp.ToText()}'
@@ -156,6 +159,7 @@ def train_and_evaluate(
156159
f'Expecting exactly one training split. Got `{len(train_input_p)}`.'
157160
)
158161
train_input_p = train_input_p[0]
162+
logging.info('[PAX STATUS]: Done initializing dataset objects')
159163

160164
logging.info('train_input_p:')
161165
for line in base_hyperparams.nested_struct_to_text(
@@ -166,6 +170,7 @@ def train_and_evaluate(
166170
for line in base_hyperparams.nested_struct_to_text(task_p).splitlines(): # pytype: disable=attribute-error
167171
logging.info(' %s', line)
168172

173+
logging.info('[PAX STATUS]: Initializing decoder')
169174
if (
170175
run_decode
171176
and task_p.train.decode_interval_steps is not None
@@ -198,6 +203,7 @@ def train_and_evaluate(
198203
)
199204

200205
# Creates the task.
206+
logging.info('[PAX STATUS]: Creating task')
201207
jax_task = instantiate(task_p)
202208
if jax_task.early_stopping_fn is not None:
203209
if early_stopping_fn is None:
@@ -208,15 +214,16 @@ def train_and_evaluate(
208214
'train_and_evel function parameter.'
209215
)
210216

217+
logging.info('[PAX STATUS]: Initializing partitioner')
211218
# Creates the partitioner, which will be set up later.
212219
partitioner = experiment_config.partitioner()
213220
if not partitioner:
214221
# For the input pipeline on the Pathways client, the inputs are numpy
215222
# arrays. We rely on the Pathways to transfer the inputs, since
216223
# jax.device_put() has a larger performance overhead.
217224
reshard_inputs = (
218-
checkpointer.checkpoint_type != CheckpointType.PERSISTENCE or
219-
train_input_p.experimental_remote_input
225+
checkpointer.checkpoint_type != CheckpointType.PERSISTENCE
226+
or train_input_p.experimental_remote_input
220227
)
221228
partitioner = partitioning.create_partitioner(
222229
jax_task,
@@ -235,9 +242,11 @@ def train_and_evaluate(
235242
eval_programs = experiment_config.eval_programs()
236243

237244
# Creates the executor and run the training pipeline.
245+
logging.info('[PAX STATUS]: Creating executor.')
238246
executor = experiment_config.executor()
239247
if not executor:
240248
executor = executors.DefaultExecutor()
249+
logging.info('[PAX STATUS]: Setting up executor.')
241250
with partitioner.global_mesh or contextlib.nullcontext():
242251
executor.setup(
243252
jax_task,

0 commit comments

Comments (0)