Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions examples/advanced/edge/jobs/pt_job_adv.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def main():
max_model_version = 200
max_model_history = None
min_hole_to_fill = 10

eval_frequency = 1
local_batch_size = 10
local_epochs = 4
Expand Down Expand Up @@ -72,6 +73,8 @@ def main():
)
device_manager_config = DeviceManagerConfig(
device_selection_size=device_selection_size,
# wait for all clients to report to the server before starting
initial_min_client_num=num_leaf_nodes,
min_hole_to_fill=min_hole_to_fill,
device_reuse=False,
)
Expand Down
155 changes: 142 additions & 13 deletions nvflare/edge/assessors/buff_device_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,22 @@
# limitations under the License.

import random
from collections import Counter, defaultdict
from typing import Dict, Set

from nvflare.edge.assessors.device_manager import DeviceManager
from nvflare.edge.mud import PropKey
from nvflare.fuel.utils.validation_utils import check_positive_int


class BuffDeviceManager(DeviceManager):
def __init__(
self,
device_selection_size: int,
initial_min_client_num: int = 1,
min_hole_to_fill: int = 1,
device_reuse: bool = True,
device_sampling_strategy: str = "balanced",
):
"""Initialize the BuffDeviceManager.
BuffDeviceManager is responsible for managing the selection of devices for model training.
Expand All @@ -36,18 +40,120 @@ def __init__(
The device_reuse flag indicates whether devices can be reused across different model versions. If False, we will always select new devices when filling holes.
Args:
device_selection_size (int): Number of devices to select for each model update round.
initial_min_client_num (int): Minimum number of clients to have at the beginning. This can be useful for initial model dispatch.
min_hole_to_fill (int): Minimum number of empty slots in device selection before refilling. Defaults to 1: as soon as an update is received, immediately sample a new device and send the current task to it.
device_reuse (bool): Whether to allow reusing devices across different model versions. Defaults to True.
device_sampling_strategy (str): Strategy for sampling devices when filling selection. Defaults to "balanced".
- "balanced": try to balance the usage of devices across clients.
- "random": randomly select devices from the available pool.
"""
super().__init__()
check_positive_int("device_selection_size", device_selection_size)
check_positive_int("min_hole_to_fill", min_hole_to_fill)

self.device_selection_size = device_selection_size
self.initial_min_client_num = initial_min_client_num
self.min_hole_to_fill = min_hole_to_fill
self.device_reuse = device_reuse
self.device_sampling_strategy = device_sampling_strategy
# also keep track of the current selection version and used devices
self.current_selection_version = 0
self.used_devices = {}
# keep a map of device_id -> client_name
self.device_client_map = {}

def _balanced_device_sampling(self, usable_devices: Set[str], num_holes: int) -> Set[str]:
    """Sample up to ``num_holes`` devices, spreading the picks evenly across clients.

    Devices are grouped by the client recorded for them in
    ``self.device_client_map``. Selection then proceeds in rounds over a
    randomly ordered list of clients, taking one random device per client per
    round. As a result the per-client counts differ by at most one, except for
    clients that have run out of devices.

    Devices that have no entry in ``self.device_client_map`` cannot be
    balanced, so they are used only as a fallback: once every mapped device
    has been taken and holes remain, the rest is drawn from them at random.
    When no device has a mapping at all this degenerates to plain random
    sampling.

    Args:
        usable_devices: Set of device IDs that can be selected. Not modified.
        num_holes: Number of devices to sample.

    Returns:
        Set of selected device IDs, holding
        ``min(num_holes, len(usable_devices))`` IDs (empty if ``num_holes <= 0``).
    """
    if not usable_devices or num_holes <= 0:
        return set()

    # Bucket the candidates by owning client; keep the ones we know nothing about aside.
    devices_by_client = defaultdict(list)
    unmapped_devices = []
    for device_id in usable_devices:
        if device_id in self.device_client_map:
            devices_by_client[self.device_client_map[device_id]].append(device_id)
        else:
            unmapped_devices.append(device_id)

    # Shuffle inside each bucket so pop() yields a random device, and shuffle the
    # bucket order so no client is systematically favoured when holes run out mid-round.
    pools = list(devices_by_client.values())
    for pool in pools:
        random.shuffle(pool)
    random.shuffle(pools)

    selected_devices = set()
    # Round-robin: every client that still has devices contributes one per round.
    while pools and len(selected_devices) < num_holes:
        for pool in pools:
            selected_devices.add(pool.pop())
            if len(selected_devices) == num_holes:
                break
        # Drop exhausted clients before the next round.
        pools = [pool for pool in pools if pool]

    # Fallback: fill whatever is still missing from devices without a known client,
    # so holes are not left empty while usable devices exist.
    shortfall = num_holes - len(selected_devices)
    if shortfall > 0 and unmapped_devices:
        selected_devices.update(random.sample(unmapped_devices, min(shortfall, len(unmapped_devices))))

    return selected_devices

def update_available_devices(self, devices: Dict, fl_ctx) -> None:
self.available_devices.update(devices)
Expand All @@ -56,6 +162,11 @@ def update_available_devices(self, devices: Dict, fl_ctx) -> None:
f"assessor got reported {len(devices)} available devices from child. "
f"total num available devices: {len(self.available_devices)}",
)
# add new devices to device_client_map
for device_id, device in devices.items():
client_name = device.to_dict().get(PropKey.CLIENT_NAME)
if client_name:
self.device_client_map[device_id] = client_name

def fill_selection(self, current_model_version: int, fl_ctx) -> None:
num_holes = self.device_selection_size - len(self.current_selection)
Expand All @@ -66,17 +177,30 @@ def fill_selection(self, current_model_version: int, fl_ctx) -> None:
usable_devices = set(self.available_devices.keys()) - set(self.used_devices.keys())

if usable_devices:
for _ in range(num_holes):
device_id = random.choice(list(usable_devices))
usable_devices.remove(device_id)
# current_selection keeps track of devices selected for a particular model version
self.current_selection[device_id] = current_model_version
self.used_devices[device_id] = {
"model_version": current_model_version,
"selection_version": self.current_selection_version,
}
if not usable_devices:
break
if self.device_sampling_strategy == "balanced":
# try to balance the usage of devices across clients
selected_devices = self._balanced_device_sampling(usable_devices, num_holes)
for device_id in selected_devices:
# current_selection keeps track of devices selected for a particular model version
self.current_selection[device_id] = self.current_selection_version
self.used_devices[device_id] = {
"model_version": current_model_version,
"selection_version": self.current_selection_version,
}
elif self.device_sampling_strategy == "random":
for _ in range(num_holes):
device_id = random.choice(list(usable_devices))
usable_devices.remove(device_id)
# current_selection keeps track of devices selected for a particular model version
self.current_selection[device_id] = self.current_selection_version
self.used_devices[device_id] = {
"model_version": current_model_version,
"selection_version": self.current_selection_version,
}
if not usable_devices:
break
else:
raise ValueError(f"Invalid device sampling strategy: {self.device_sampling_strategy}")
self.log_info(
fl_ctx,
f"current selection with {len(self.current_selection)} items: V{self.current_selection_version}; {dict(sorted(self.current_selection.items()))}",
Expand All @@ -95,11 +219,16 @@ def remove_devices_from_used(self, devices: Set[str], fl_ctx) -> None:
for device_id in devices:
self.used_devices.pop(device_id, None)

def has_enough_devices(self, fl_ctx) -> bool:
def has_enough_devices_and_clients(self, fl_ctx) -> bool:
num_holes = self.device_selection_size - len(self.current_selection)
usable_devices = set(self.available_devices.keys()) - set(self.used_devices.keys())
num_usable_devices = len(usable_devices)
return num_usable_devices >= num_holes
if num_usable_devices < num_holes:
return False

# Further check if we have enough clients
unique_clients = set(self.device_client_map.values())
return len(unique_clients) >= self.initial_min_client_num

def should_fill_selection(self, fl_ctx) -> bool:
num_holes = self.device_selection_size - len(self.current_selection)
Expand Down
6 changes: 3 additions & 3 deletions nvflare/edge/assessors/device_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,14 +103,14 @@ def should_fill_selection(self, fl_ctx: FLContext) -> bool:
pass

@abstractmethod
def has_enough_devices(self, fl_ctx: FLContext) -> bool:
"""Check if there are enough devices available to start task distribution.
def has_enough_devices_and_clients(self, fl_ctx: FLContext) -> bool:
"""Check if there are enough devices and clients available to start task distribution.

Args:
fl_ctx: FLContext object

Returns:
bool: True if there are enough devices to start task distribution, False otherwise
bool: True if there are enough devices and clients to start task distribution, False otherwise
"""
pass

Expand Down
4 changes: 2 additions & 2 deletions nvflare/edge/assessors/model_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def _do_child_update(self, update: Shareable, fl_ctx: FLContext) -> (bool, Optio
if report.available_devices:
self.device_manager.update_available_devices(report.available_devices, fl_ctx)
# Reset wait timer if we now have enough devices
if self.device_wait_start_time is not None and self.device_manager.has_enough_devices(fl_ctx):
if self.device_wait_start_time is not None and self.device_manager.has_enough_devices_and_clients(fl_ctx):
self.device_wait_start_time = None
self.log_info(fl_ctx, "Sufficient devices now available, resetting wait timer")

Expand Down Expand Up @@ -235,7 +235,7 @@ def _do_child_update(self, update: Shareable, fl_ctx: FLContext) -> (bool, Optio
# Handle device selection
if self.device_manager.should_fill_selection(fl_ctx):
# check if we have enough devices to fill selection
if self.device_manager.has_enough_devices(fl_ctx):
if self.device_manager.has_enough_devices_and_clients(fl_ctx):
if self.model_manager.current_model_version == 0:
self.log_info(fl_ctx, "Generate initial model and fill selection")
self.model_manager.generate_new_model(fl_ctx)
Expand Down
13 changes: 13 additions & 0 deletions nvflare/edge/tools/edge_fed_buff_recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@ class DeviceManagerConfig:
Attributes:
device_selection_size: Number of devices to select for each training round.
Default: 100
initial_min_client_num: Minimum number of clients to have at the beginning.
This can be useful for initial model dispatch.
Default: 1
min_hole_to_fill: Minimum number of model updates to wait for before
sampling the next batch of devices and dispatching the current global model.
- If set to 1, the server immediately dispatches the current global model to a sampled device.
Expand All @@ -95,20 +98,28 @@ class DeviceManagerConfig:
if False, devices will be selected only once, which could be realistic for real-world scenarios where the
device pool is huge while participation is random.
Default: True (always reuse / include the existing devices for further learning)
device_sampling_strategy: Strategy for sampling devices when filling selection.
- "balanced": try to balance the usage of devices across clients.
- "random": randomly select devices from the available pool.
Default: "balanced"
"""

def __init__(
    self,
    device_selection_size: int = 100,
    initial_min_client_num: int = 1,
    min_hole_to_fill: int = 1,
    device_reuse: bool = True,
    device_sampling_strategy: str = "balanced",
):
    """Validate and store the device-manager settings.

    Args:
        device_selection_size: Number of devices to select for each training round.
        initial_min_client_num: Minimum number of clients to have at the beginning.
        min_hole_to_fill: Minimum number of model updates to wait for before
            sampling the next batch of devices.
        device_reuse: Whether devices may be selected again in later rounds.
        device_sampling_strategy: Either "balanced" or "random".

    Raises:
        ValueError: If ``min_hole_to_fill`` is larger than ``device_selection_size``,
            or ``device_sampling_strategy`` is not "balanced" or "random".
    """
    # Validate first so a bad configuration fails at construction time rather than
    # later, when the device manager tries to fill its selection.
    if min_hole_to_fill > device_selection_size:
        raise ValueError("min_hole_to_fill needs to be smaller than or equal to device_selection_size")
    if device_sampling_strategy not in ("balanced", "random"):
        raise ValueError(f"Invalid device sampling strategy: {device_sampling_strategy}")

    self.device_selection_size = device_selection_size
    self.initial_min_client_num = initial_min_client_num
    self.min_hole_to_fill = min_hole_to_fill
    self.device_reuse = device_reuse
    self.device_sampling_strategy = device_sampling_strategy


class SimulationConfig:
Expand Down Expand Up @@ -310,8 +321,10 @@ def _configure_job(self, job: EdgeJob):

device_manager = BuffDeviceManager(
device_selection_size=self.device_manager_config.device_selection_size,
initial_min_client_num=self.device_manager_config.initial_min_client_num,
min_hole_to_fill=self.device_manager_config.min_hole_to_fill,
device_reuse=self.device_manager_config.device_reuse,
device_sampling_strategy=self.device_manager_config.device_sampling_strategy,
)
device_manager_id = job.to_server(device_manager, id="device_manager")

Expand Down
Loading