Commit c0a470b

Add profiling option to the fuji running script (#1491)
AXLearn has a JAX profiling option that is enabled by passing `--jax_profiler_port=9999` to the submission/deployment script. This PR adds a `--trace_steps` option to `fuji-train-perf.py`: a list of training steps at which profiles are captured.
1 parent 92243a8 commit c0a470b

1 file changed, +12 -0 lines changed


.github/container/fuji-train-perf.py

Lines changed: 12 additions & 0 deletions
@@ -127,6 +127,13 @@
     help="Final log file.",
 )
 
+parser.add_argument(
+    "--trace_steps",
+    type=list,
+    default=None,
+    help="Steps to trace (e.g. [1, 20, 50]). To profile the training give `--jax_profiler_port 9999` to the script.",
+)
+
 parser.add_argument("--world_size", type=int, help="Total number of GPUs")
 
 
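A note on the `type=list` declaration in the hunk above: argparse applies `type` to the raw argument string, so `list` splits the value into individual characters instead of parsing integers. The sketch below reproduces that behaviour and shows one possible alternative using `nargs="+"` with `type=int`. The helper names and the alternative are illustrative assumptions, not part of this commit.

```python
import argparse


def parse_as_committed(argv):
    """Parse --trace_steps with type=list, as declared in the diff."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--trace_steps", type=list, default=None)
    return parser.parse_args(argv).trace_steps


def parse_as_int_list(argv):
    """Hypothetical alternative: accept space-separated integer steps."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--trace_steps", type=int, nargs="+", default=None)
    return parser.parse_args(argv).trace_steps


if __name__ == "__main__":
    # list("[1, 20, 50]") yields single characters, not integers.
    print(parse_as_committed(["--trace_steps", "[1, 20, 50]"]))
    # With nargs="+", each token is converted by int().
    print(parse_as_int_list(["--trace_steps", "1", "20", "50"]))
```

With the alternative, the command line would take the form `--trace_steps 1 20 50` rather than a bracketed list.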

@@ -245,6 +252,7 @@ def main(parsed_args):
     save_checkpoint_steps = parsed_args.save_checkpoint_steps
     write_summary_steps = parsed_args.write_summary_steps
     output_log_file = parsed_args.output_log_file
+    trace_steps = parsed_args.trace_steps
     world_size = parsed_args.world_size
 
     print(
@@ -306,6 +314,10 @@ def main(parsed_args):
     trainer_config.checkpointer.save_policy.n = save_checkpoint_steps
     trainer_config.checkpointer.keep_every_n_steps = save_checkpoint_steps
     trainer_config.summary_writer.write_every_n_steps = write_summary_steps
+
+    if trace_steps is not None:
+        trainer_config.start_trace_steps = trace_steps
+
     # Call AXLearn Jax setup
     launch.setup()
     # Setup the config
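
The last hunk overrides the trainer's trace steps only when the flag was supplied, so the config keeps its existing value when the flag is omitted. A minimal sketch of that guard follows, using a hypothetical stand-in for the trainer config. `StubTrainerConfig`, its empty-list default, and `apply_trace_steps` are assumptions for illustration, not AXLearn's actual classes.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class StubTrainerConfig:
    """Hypothetical stand-in holding only the field the diff assigns."""

    # Assumed default for illustration: no steps traced.
    start_trace_steps: List[int] = field(default_factory=list)


def apply_trace_steps(
    config: StubTrainerConfig, trace_steps: Optional[List[int]]
) -> StubTrainerConfig:
    """Override start_trace_steps only when a value was given on the CLI."""
    if trace_steps is not None:
        config.start_trace_steps = trace_steps
    return config


if __name__ == "__main__":
    print(apply_trace_steps(StubTrainerConfig(), None).start_trace_steps)
    print(apply_trace_steps(StubTrainerConfig(), [1, 20, 50]).start_trace_steps)
```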
