
Add logging for dataset input and checkpoint saving #1190


Open

linxiulei wants to merge 2 commits into main

Conversation

@linxiulei commented May 18, 2025

This PR adds:

  • A logging utility class that logs a warning when an operation's execution time exceeds a specified threshold.
  • Use of that utility class around dataset input fetches (per iteration) and checkpoint saving, so we can observe executions with poor performance (a rough sketch of the idea follows below).
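
A minimal sketch of the idea (illustrative only: the class name LoggingContext matches the PR, but the body below is a simplified assumption, not the PR's actual implementation, and omits the timeout behavior discussed later in this thread):

import logging
import time


class LoggingContext:
    """Sketch: logs a warning when the wrapped operation takes longer than `threshold` seconds."""

    def __init__(self, name: str, *, threshold: float):
        self._name = name
        self._threshold = threshold
        self._start = 0.0

    def __enter__(self):
        self._start = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc, tb):
        elapsed = time.perf_counter() - self._start
        if elapsed > self._threshold:
            logging.warning(
                "%s took %.1fs, exceeding the %.1fs threshold.",
                self._name, elapsed, self._threshold,
            )
        return False  # Do not suppress exceptions from the wrapped op.


# Example usage around checkpoint saving (mirrors the diff below):
# with LoggingContext("save_checkpoint", threshold=120):
#     self.save_checkpoint(evaler_summaries=evaler_summaries)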

@linxiulei requested review from ruomingp, markblee and a team as code owners May 18, 2025 20:15
@@ -1091,8 +1097,9 @@ def _run_step(
train_summaries=outputs["summaries"], force_runs=force_run_evals
)

# Checkpointer policy will decide if we should save.
self.save_checkpoint(evaler_summaries=evaler_summaries)
with LoggingContext("save_checkpoint", timeout=180, threshold=120):
Contributor

Can we make this configurable?

Author

Sure, but I am wondering which way is better to make it configurable: an env var or flags?

Contributor

The typical convention we use is that if we want to make the child configurable, then the parent config should contain a config for the child.
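
For illustration of that convention (plain dataclasses here, not axlearn's actual config API; all names are hypothetical): the parent exposes a config field for the child, so callers tune the child through the parent rather than through env vars or flags.

from dataclasses import dataclass, field


@dataclass
class LoggingContextConfig:
    # Hypothetical child config for the logging utility.
    timeout: float = 180.0    # Seconds before warning that the op is still running.
    threshold: float = 120.0  # Seconds after which a completed op is logged as slow.


@dataclass
class TrainerConfig:
    # Hypothetical parent config embedding the child's config.
    checkpoint_logging: LoggingContextConfig = field(default_factory=LoggingContextConfig)


cfg = TrainerConfig()
cfg.checkpoint_logging.threshold = 60.0  # Configure the child via the parent.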

import time


class LoggingContext:
Contributor

Could you explain how this differs from the other monitor functionality that already exists, like GPUMonitor, etc.? If they serve overlapping purposes, should they be unified?

Author

Can you please clarify what other monitor functionality we already have? If they overlap, we can unify them.

Contributor

TfSummaryMonitor, GPUMonitor, GoodputMonitor. I might be forgetting some additional ones.

Author

Thanks, I checked that functionality and believe there is no overlap: this PR is mostly for logging specific operations' execution time (the utility lets you monitor any operation you want), whereas the monitors you mentioned serve specific purposes (goodput, GPU hardware status, etc.).

self.save_checkpoint(evaler_summaries=evaler_summaries)
with LoggingContext("save_checkpoint", timeout=180, threshold=120):
    # Checkpointer policy will decide if we should save.
    self.save_checkpoint(evaler_summaries=evaler_summaries)
Contributor

Since checkpoint syncing is asynchronous, do we expect to see timeouts inside this block?

Author

Yes, IIUC checkpoint saving is asynchronous only if the previous save has completed; otherwise it may get stuck in https://github.com/apple/axlearn/blob/main/axlearn/common/array_serialization.py#L554

@@ -581,7 +582,12 @@ def run(
output = None
stop_trace_step = None

input_iterator = self.input.batches(self._input_iter)
input_iterator = LoggingIterator(
Contributor

Since it is already recorded in the Goodput measurement, I wonder if we need separate logging here.

If I am reading it correctly, this LoggingIterator creates one timer thread every time next() is called, which seems to add a lot of thread creation/cancellation overhead to the main process.

Author

Assuming you mean this code:

                    self._maybe_record_event(measurement.Event.START_DATA_LOADING)
                    try:
                        input_batch = next(input_iterator)
                        self._maybe_record_event(measurement.Event.END_DATA_LOADING)

If the next() somehow hangs and then the whole process aborts, would Goodput measurement help catch that?

My understanding is that this iterator invokes next() every few seconds, depending on the training scale, so creating and cleaning up a thread every few seconds should add only minimal overhead (empirically, creating one thread takes 10 to 100 µs).

However, if you still have concerns about the overhead, we could create a reusable timer thread.
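
For reference, a thread-free variant of the iterator wrapper (an illustrative sketch with a hypothetical name, not the PR's LoggingIterator): it measures only the completed next() call, so it avoids per-iteration thread creation but cannot warn while a fetch is still hanging.

import logging
import time
from typing import Generic, Iterator, TypeVar

T = TypeVar("T")


class ThresholdLoggingIterator(Generic[T]):
    """Sketch: wraps an iterator and logs next() calls slower than `threshold` seconds."""

    def __init__(self, it: Iterator[T], *, name: str, threshold: float):
        self._it = it
        self._name = name
        self._threshold = threshold

    def __iter__(self):
        return self

    def __next__(self) -> T:
        start = time.perf_counter()
        batch = next(self._it)  # StopIteration propagates normally.
        elapsed = time.perf_counter() - start
        if elapsed > self._threshold:
            logging.warning(
                "%s: next() took %.1fs (threshold %.1fs).",
                self._name, elapsed, self._threshold,
            )
        return batch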

Comment on lines 11 to 15
"""A context manager for monitoring the execution time of operations and
logging warnings based on defined timeouts and thresholds.

This context manager can be used to:
- Log a warning if an operation exceeds a specified timeout while still
running.
- Log a warning if an operation's total execution time exceeds a specified
threshold.
Contributor

It's not clear that we need logging-upon-timeout, given the threading concerns raised by @Ethanlm. I wonder if a simpler semantic would be sufficient:

  • LoggingContext always waits until the op finishes and logs a warning if the execution time exceeds the threshold;
  • We rely on the watchdog
    watchdog_timeout_seconds: Optional[float] = None
    to tell us where the threads are stuck.

WDYT?

Author

I agree with starting with something simpler, so I changed this PR to only implement logging upon op completion.

However, I feel logging-upon-timeout is important because program hangs are difficult to debug by digging through the tracebacks that the watchdog timeout prints for so many workers. A clear log message may help identify the workers that cause the hangs.
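
To illustrate the trade-off being discussed (a hedged sketch with a hypothetical name, not this PR's implementation): warning upon timeout requires something like a timer thread that fires while the op is still running, which is exactly the thread-management cost raised above, whereas logging only upon completion needs no extra thread.

import logging
import threading


class TimeoutWarningContext:
    """Sketch: warns while the wrapped op is *still running* after `timeout` seconds."""

    def __init__(self, name: str, *, timeout: float):
        # One timer thread per context; it fires only if the op outlives `timeout`.
        self._timer = threading.Timer(
            timeout,
            lambda: logging.warning("%s still running after %.1fs.", name, timeout),
        )

    def __enter__(self):
        self._timer.start()
        return self

    def __exit__(self, exc_type, exc, tb):
        self._timer.cancel()  # Op finished (or raised); suppress the timeout warning.
        return False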

linxiulei added 2 commits May 22, 2025 15:58
Introduce logging utilities designed to proactively identify and flag operations exceeding performance expectations.