checkpointing: use CheckpointTransport abstraction #81
Conversation
e1a59f0 to 0105685 Compare
0105685 to 0e29ef9 Compare
Overall looks good! I had some comments/questions, but some of them are just to clarify my understanding of the changes.
@@ -104,6 +104,7 @@ def __init__(
        port: Optional[int] = None,
        hostname: str = socket.gethostname(),
        heartbeat_interval: timedelta = timedelta(milliseconds=100),
        checkpoint_transport: Optional[CheckpointTransport[Dict[str, T]]] = None,
How do you envision the CheckpointTransport abstraction being extended and provided by users?
I'm not really anticipating users implementing one from scratch, but I do want to provide two implementations here: one with CheckpointServer and another that's PG-based.
Advanced users can then use those by themselves, or patch things like CheckpointServer to provide more robust implementations for their use cases.
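To make the shape of that abstraction concrete, here's a minimal sketch of what a symmetric transport interface could look like. The method names and signatures below are illustrative assumptions, not necessarily the actual API this PR adds:

```python
from abc import ABC, abstractmethod
from datetime import timedelta
from typing import Generic, List, TypeVar

T = TypeVar("T")


class CheckpointTransport(ABC, Generic[T]):
    """Symmetric transport: sender and receiver both participate in recovery."""

    @abstractmethod
    def metadata(self) -> str:
        """Opaque metadata (e.g. an address) handed to recovering ranks."""

    @abstractmethod
    def send_checkpoint(
        self, dst_ranks: List[int], step: int, state_dict: T, timeout: timedelta
    ) -> None:
        """Send this node's state to the ranks that are recovering."""

    @abstractmethod
    def recv_checkpoint(
        self, src_rank: int, metadata: str, step: int, timeout: timedelta
    ) -> T:
        """Fetch state from the rank this node is recovering from."""
```

Either of the two implementations mentioned above (CheckpointServer-backed or PG-based) would then plug in through the new checkpoint_transport constructor argument shown in the diff earlier in this thread.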
src/manager.rs
Outdated
    })
    .collect();
let all_recovering_ranks_set = all_recovering_ranks.iter().collect::<HashSet<_>>();
let serving_ranks: Vec<usize> = participants
nit: I was initially confused by the name serving_ranks; maybe "stable_ranks" or "up_to_date_ranks" would also make sense?
done
torchft/manager.py
Outdated
    self._rank, timeout=self._timeout
)
if allow_heal:
    if quorum.recovering_ranks:
the if recovering_ranks condition is always true if heal is true as well, right?
They're flipped: this checks whether anyone will be requesting a checkpoint from this node, whereas heal below is whether this node will be requesting a checkpoint.
heal == quorum.recover_rank is not None, which is distinct from quorum.recovering_ranks.
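In other words, the quorum result carries two separate pieces of recovery information: who this node fetches state from, and who fetches state from this node. A small self-contained sketch of that distinction (field names follow the recover_src_rank/recover_dst_ranks rename mentioned later in this thread, and are assumptions rather than the exact torchft structs):

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class QuorumResult:
    # Rank this node will recover *from*; None means this node is up to date.
    recover_src_rank: Optional[int] = None
    # Ranks that will recover *from* this node; empty means nobody needs us.
    recover_dst_ranks: List[int] = field(default_factory=list)


def describe(quorum: QuorumResult) -> str:
    # "heal": this node is behind and must fetch a checkpoint from someone.
    heal = quorum.recover_src_rank is not None
    # Serving is the other direction: other ranks fetch their state from us.
    serving = bool(quorum.recover_dst_ranks)
    return f"heal={heal}, serving={serving}"


# An up-to-date node that two other ranks will recover from:
assert describe(QuorumResult(recover_dst_ranks=[2, 3])) == "heal=False, serving=True"
# A node that is behind and will fetch its state from rank 1:
assert describe(QuorumResult(recover_src_rank=1)) == "heal=True, serving=False"
```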
// Compute the details for workers at max step.
let max_step = participants.iter().map(|p| p.step).max().unwrap();
let max_participants: Vec<&QuorumMember> =
just checking: once we reach compute_quorum_results, will there always be a nonzero number of ranks at max_step in the quorum?
yes, there's always at least one member at max_step, since it's defined as the maximum step of the participating members and we'll always have at least one member
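For illustration, the same invariant as a runnable sketch (plain Python standing in for the Rust in src/manager.rs; names here are illustrative, not the actual structs):

```python
from dataclasses import dataclass
from typing import List


@dataclass
class QuorumMember:
    replica_id: str
    step: int


def members_at_max_step(participants: List[QuorumMember]) -> List[QuorumMember]:
    # The quorum always has at least one participant, and max_step is defined
    # over those same participants, so the filtered list is never empty.
    assert participants
    max_step = max(p.step for p in participants)
    return [p for p in participants if p.step == max_step]


members = [QuorumMember("a", 10), QuorumMember("b", 12), QuorumMember("c", 12)]
assert len(members_at_max_step(members)) == 2
```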
src/manager.rs
Outdated
Ok(ManagerQuorumResponse {
    quorum_id: quorum.quorum_id,
    // address is used for looking up the checkpoint server address.
    recover_manager_address: primary.address.clone(),
how/why do we use the recover_manager_address if the checkpoint loading is done round robin?
Good catch -- this shouldn't be the primary, it should be the recover_replica manager address 🤦
src/lib.rs
Outdated
replica_rank: 0,
replica_world_size: 1,
recover_manager_address: "".to_string(),
recover_rank: None,
nit: maybe rename recover_rank to recovery_source_rank? Or add a comment that clarifies this represents the rank from which the current replica will recover
renamed to recover_src_rank and recover_dst_ranks
0e29ef9 to 0033713 Compare
This is a major refactoring of the live checkpointing code.
New behavior:
Key refactors:
- Added a CheckpointTransport abstraction, which CheckpointServer implements. This abstraction is designed to be symmetric for recovery strategies that require both sender and receiver to be aware of the communication. This will allow for recovering via ProcessGroups/NCCL, which can be significantly faster.
- ManagerClient.quorum now returns a QuorumResult struct rather than a named tuple.
- Renamed checkpoint_address to checkpoint_metadata, as not all transports will be using an address.
- Added a compute_quorum_result method and better unit tests for it.
Test plan: