diff --git a/docs/programming_guide/memory_management.rst b/docs/programming_guide/memory_management.rst index ea7e972799..8ba6b91e0a 100644 --- a/docs/programming_guide/memory_management.rst +++ b/docs/programming_guide/memory_management.rst @@ -5,7 +5,7 @@ Memory Management ################### This guide describes memory management techniques for long-running federated learning jobs -using Python, PyTorch, and glibc. +using Python, PyTorch, and glibc/jemalloc. .. contents:: Table of Contents :local: @@ -21,7 +21,24 @@ RSS (Resident Set Size) can grow continuously due to: - glibc memory arena fragmentation - PyTorch CUDA cache retention -NVFlare provides utilities and configuration options to manage memory effectively. +NVFlare provides utilities and configuration options to manage memory effectively on +both server and client sides. The framework automatically detects the memory allocator +in use (glibc or jemalloc) and adapts its cleanup strategy accordingly. + +Allocator Support +================= + +NVFlare supports two memory allocators: + +**glibc (default on most Linux)** + Uses ``malloc_trim()`` to release free heap pages to the OS. + Requires ``MALLOC_ARENA_MAX`` for optimal memory behavior. + +**jemalloc (recommended for PyTorch)** + Uses auto-decay for memory management. Configure via ``MALLOC_CONF``. + No ``malloc_trim()`` calls needed (jemalloc handles this automatically). + +NVFlare automatically detects which allocator is in use at runtime. Platform Compatibility ====================== @@ -139,14 +156,130 @@ Memory cleanup has minimal overhead in typical federated learning workloads: management with negligible performance impact. Only disable (``=0``) if you've measured and confirmed RSS is stable without cleanup. +Client-Side Memory Cleanup +========================== + +The FedAvg recipe and ScriptRunner support automatic memory cleanup on clients via +``client_memory_gc_rounds`` and ``cuda_empty_cache`` parameters. + +Configuration +------------- + +.. code-block:: python + + from nvflare.recipe.fedavg import FedAvgRecipe + + recipe = FedAvgRecipe( + name="my_job", + min_clients=4, + num_rounds=100, + train_script="client.py", + + # Server-side cleanup + server_memory_gc_rounds=5, + + # Client-side cleanup + client_memory_gc_rounds=1, # Cleanup every round + cuda_empty_cache=True, # Clear GPU cache + ) + +Swarm Learning Configuration +---------------------------- + +Swarm Learning uses ``memory_gc_rounds`` (not ``memory_gc_counts``) and +``cuda_empty_cache`` on ``SimpleSwarmLearningRecipe``: + +.. code-block:: python + + from nvflare.app_opt.pt.recipes.swarm import SimpleSwarmLearningRecipe + + recipe = SimpleSwarmLearningRecipe( + name="swarm_job", + model=MyModel(), + num_rounds=10, + train_script="train.py", + memory_gc_rounds=1, # Cleanup every round on trainer and aggregator roles + cuda_empty_cache=True, + ) + +.. note:: + + ``memory_gc_rounds`` and ``cuda_empty_cache`` are top-level Swarm recipe arguments. + Do not pass them inside ``train_args`` (they are reserved keys). + +**Parameters:** + +- ``client_memory_gc_rounds``: Run cleanup every N rounds on client (0 = disabled) +- ``cuda_empty_cache``: If True, call ``torch.cuda.empty_cache()`` on cleanup +- ``memory_gc_rounds`` (Swarm): Run cleanup every N rounds (0 = disabled) + +What It Does +------------ + +When enabled, after each ``flare.send()`` on the client: + +1. Runs Python garbage collection (``gc.collect()``) +2. For glibc: Returns free heap pages to OS (``malloc_trim()``) +3. For jemalloc: Relies on auto-decay (no manual action needed) +4. Optionally clears PyTorch CUDA cache + +**Note:** The cleanup is transparent to the user's training script. No code changes +are required in ``train.py``. + +External Process Support +------------------------ + +For external process execution (``launch_external_process=True``), memory settings +are passed via environment variables: + +- ``NVFLARE_CLIENT_MEMORY_GC_ROUNDS``: Cleanup interval +- ``NVFLARE_CUDA_EMPTY_CACHE``: GPU cache cleanup (``true``/``false``) + Recommended Settings ==================== -+--------+-------------------------------+----------------------+ -| Role | ``server_memory_gc_rounds`` | ``MALLOC_ARENA_MAX`` | -+========+===============================+======================+ -| Server | 5 | 4 | -+--------+-------------------------------+----------------------+ ++--------+-----------------------------+-----------------------------+----------------------+----------------------+ +| Role | ``server_memory_gc_rounds`` | ``client_memory_gc_rounds`` | ``MALLOC_ARENA_MAX`` | ``cuda_empty_cache`` | ++========+=============================+=============================+======================+======================+ +| Server | 5 | N/A | 4 | N/A | ++--------+-----------------------------+-----------------------------+----------------------+----------------------+ +| Client | N/A | 1 | 2 | True (for GPU) | ++--------+-----------------------------+-----------------------------+----------------------+----------------------+ + +Using jemalloc +============== + +For PyTorch workloads, jemalloc is recommended over glibc malloc. NVFlare's startup +scripts automatically detect and use jemalloc if available. + +Startup Script +-------------- + +The generated ``sub_start.sh`` script includes jemalloc auto-detection: + +.. code-block:: bash + + # Auto-detects jemalloc at standard locations + for JEMALLOC in /usr/lib/x86_64-linux-gnu/libjemalloc.so.2 \ + /usr/lib64/libjemalloc.so.2 \ + /usr/local/lib/libjemalloc.so; do + if [ -f "$JEMALLOC" ]; then + export LD_PRELOAD="${LD_PRELOAD:+$LD_PRELOAD:}$JEMALLOC" + export MALLOC_CONF="${MALLOC_CONF:-dirty_decay_ms:5000,muzzy_decay_ms:5000}" + break + fi + done + +Installing jemalloc +------------------- + +.. code-block:: bash + + # Ubuntu/Debian + apt-get install libjemalloc2 + + # RHEL/CentOS + yum install jemalloc API Reference ============= @@ -158,15 +291,29 @@ cleanup_memory from nvflare.fuel.utils.memory_utils import cleanup_memory - cleanup_memory(torch_cuda_empty_cache=True) + cleanup_memory(cuda_empty_cache=True) -**Signature:** ``cleanup_memory(torch_cuda_empty_cache: bool = False) -> None`` +**Signature:** ``cleanup_memory(cuda_empty_cache: bool = False) -> None`` -Performs memory cleanup: +Performs allocator-aware memory cleanup: 1. Runs ``gc.collect()`` -2. Calls ``malloc_trim(0)`` (Linux/glibc only, safe no-op elsewhere) -3. Optionally calls ``torch.cuda.empty_cache()`` +2. For glibc: Calls ``malloc_trim(0)`` +3. For jemalloc: Relies on auto-decay (no action needed) +4. Optionally calls ``torch.cuda.empty_cache()`` + +get_allocator_type +------------------ + +.. code-block:: python + + from nvflare.fuel.utils.memory_utils import get_allocator_type + + allocator = get_allocator_type() # "glibc", "jemalloc", or "unknown" + +**Signature:** ``get_allocator_type() -> str`` + +Detects which memory allocator is in use at runtime. Result is cached. try_malloc_trim --------------- @@ -195,12 +342,21 @@ High RSS on Server 1. Check ``MALLOC_ARENA_MAX`` is set 2. Enable ``server_memory_gc_rounds=5`` -3. Monitor with ``top`` or ``htop`` +3. Consider using jemalloc (LD_PRELOAD) +4. Monitor with ``top`` or ``htop`` + +High RSS on Client +------------------ + +1. Check ``MALLOC_ARENA_MAX=2`` is set +2. Enable ``client_memory_gc_rounds=1`` +3. Enable ``cuda_empty_cache=True`` for GPU +4. Consider using jemalloc OOM Errors ---------- 1. Reduce batch size -2. Enable memory cleanup every round (``server_memory_gc_rounds=1``) +2. Enable memory cleanup every round (``client_memory_gc_rounds=1`` or ``server_memory_gc_rounds=1``) 3. Check for memory leaks in training code - +4. Use jemalloc with appropriate decay settings diff --git a/nvflare/app_common/ccwf/ccwf_job.py b/nvflare/app_common/ccwf/ccwf_job.py index d36df6d7ac..c81b0e70bc 100644 --- a/nvflare/app_common/ccwf/ccwf_job.py +++ b/nvflare/app_common/ccwf/ccwf_job.py @@ -84,6 +84,8 @@ def __init__( request_to_submit_result_interval: float = 1.0, max_concurrent_submissions: int = 1, enable_tensor_disk_offload: bool = False, + memory_gc_rounds: int = 1, + cuda_empty_cache: bool = False, ): # the executor could be a wrapper object that adds real Executor when added to job! validate_object_for_job("executor", executor, Executor) @@ -115,6 +117,8 @@ def __init__( self.request_to_submit_result_interval = request_to_submit_result_interval self.max_concurrent_submissions = max_concurrent_submissions self.enable_tensor_disk_offload = enable_tensor_disk_offload + self.memory_gc_rounds = memory_gc_rounds + self.cuda_empty_cache = cuda_empty_cache class CyclicServerConfig: @@ -275,6 +279,8 @@ def add_swarm( request_to_submit_result_interval=client_config.request_to_submit_result_interval, max_concurrent_submissions=client_config.max_concurrent_submissions, enable_tensor_disk_offload=client_config.enable_tensor_disk_offload, + memory_gc_rounds=client_config.memory_gc_rounds, + cuda_empty_cache=client_config.cuda_empty_cache, ) self.to_clients(client_controller, tasks=["swarm_*"]) if not self.executor: diff --git a/nvflare/app_common/ccwf/swarm_client_ctl.py b/nvflare/app_common/ccwf/swarm_client_ctl.py index 3a4a3b706a..b8a2c0212d 100644 --- a/nvflare/app_common/ccwf/swarm_client_ctl.py +++ b/nvflare/app_common/ccwf/swarm_client_ctl.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import copy -import gc import random import threading import time @@ -110,9 +109,6 @@ def gather(self, client_name: str, result: Shareable, fl_ctx: FLContext) -> Shar if client_status: client_status.busy = False - # force garbage collection after each gather - gc.collect() - def can_accept_submission(self, client_name: str, result: Shareable, fl_ctx: FLContext) -> str: with self.perm_lock: result_round = result.get_header(AppConstants.CURRENT_ROUND) @@ -295,6 +291,8 @@ def __init__( request_to_submit_result_interval: float = 1.0, max_concurrent_submissions: int = 1, enable_tensor_disk_offload: bool = False, + memory_gc_rounds: int = 1, + cuda_empty_cache: bool = False, ): """ Constructor of a ClientSideController object. @@ -325,6 +323,12 @@ def __init__( max_concurrent_submissions: max number of concurrent submissions allowed on the aggregation client. enable_tensor_disk_offload: download tensors to disk during FOBS streaming instead of into memory. Reduces peak memory during aggregation. Aggregators must handle lazy refs. + memory_gc_rounds: run gc.collect() + malloc_trim on the aggregator every N FL rounds. + Defaults to 1 (every round) to match legacy behavior where gc.collect() was called + unconditionally after each trainer submission. Set to 0 to disable. + cuda_empty_cache: also call torch.cuda.empty_cache() during aggregator-side cleanup. + In swarm learning the aggregator runs on the same client as the trainer, so GPU + memory may be relevant. Defaults to False. Note that if the max_concurrent_submissions is set to 1, it practically means that all training results will be submitted to the aggregation client sequentially. This lowers the resource pressure on @@ -386,6 +390,9 @@ def __init__( self.last_aggr_round_done = -1 self.enable_tensor_disk_offload = enable_tensor_disk_offload self._previous_enable_tensor_disk_offload = None + self.memory_gc_rounds = memory_gc_rounds + self.cuda_empty_cache = cuda_empty_cache + self._aggr_round_count = 0 def process_config(self, fl_ctx: FLContext): all_clients = self.get_config_prop(Constant.CLIENTS) @@ -585,22 +592,29 @@ def _end_gather(self, gatherer: Gatherer): # determine the best global result self._distribute_final_results(aggr_result, fl_ctx) - return - # continue next round - next_round_data = self.shareable_generator.learnable_to_shareable(global_weights, fl_ctx) - assert isinstance(next_round_data, Shareable) + else: + # continue next round + next_round_data = self.shareable_generator.learnable_to_shareable(global_weights, fl_ctx) + assert isinstance(next_round_data, Shareable) - best_round = aggr_result.get_header(Constant.ROUND) - best_metric = aggr_result.get_header(Constant.METRIC) - best_client = aggr_result.get_header(Constant.CLIENT) + best_round = aggr_result.get_header(Constant.ROUND) + best_metric = aggr_result.get_header(Constant.METRIC) + best_client = aggr_result.get_header(Constant.CLIENT) - if best_client: - next_round_data.set_header(Constant.ROUND, best_round) - next_round_data.set_header(Constant.CLIENT, best_client) - next_round_data.set_header(Constant.METRIC, best_metric) + if best_client: + next_round_data.set_header(Constant.ROUND, best_round) + next_round_data.set_header(Constant.CLIENT, best_client) + next_round_data.set_header(Constant.METRIC, best_metric) + + self._scatter(next_round_data, gatherer.for_round + 1, gatherer.fl_ctx) + + if self.memory_gc_rounds > 0: + self._aggr_round_count += 1 + if self._aggr_round_count % self.memory_gc_rounds == 0: + from nvflare.fuel.utils.memory_utils import cleanup_memory - self._scatter(next_round_data, gatherer.for_round + 1, gatherer.fl_ctx) + cleanup_memory(cuda_empty_cache=self.cuda_empty_cache) def _ask_to_share_best_result(self, client: str, metric, fl_ctx: FLContext): # other client has best model - ask it to distribute its result diff --git a/nvflare/app_common/executors/client_api_launcher_executor.py b/nvflare/app_common/executors/client_api_launcher_executor.py index 1c8bd2a6e3..711d833b0c 100644 --- a/nvflare/app_common/executors/client_api_launcher_executor.py +++ b/nvflare/app_common/executors/client_api_launcher_executor.py @@ -50,6 +50,8 @@ def __init__( params_transfer_type: str = TransferType.FULL, config_file_name: str = CLIENT_API_CONFIG, server_expected_format: str = ExchangeFormat.NUMPY, + memory_gc_rounds: int = 0, + cuda_empty_cache: bool = False, ) -> None: """Initializes the ClientAPILauncherExecutor. @@ -107,6 +109,8 @@ def __init__( self._params_exchange_format = params_exchange_format self._params_transfer_type = params_transfer_type self._config_file_name = config_file_name + self._memory_gc_rounds = memory_gc_rounds + self._cuda_empty_cache = cuda_empty_cache def initialize(self, fl_ctx: FLContext) -> None: self.prepare_config_for_launch(fl_ctx) @@ -141,6 +145,8 @@ def prepare_config_for_launch(self, fl_ctx: FLContext): ConfigKey.ARG: pipe_export_args, }, ConfigKey.HEARTBEAT_TIMEOUT: self.heartbeat_timeout, + ConfigKey.MEMORY_GC_ROUNDS: self._memory_gc_rounds, + ConfigKey.CUDA_EMPTY_CACHE: self._cuda_empty_cache, } config_data = { diff --git a/nvflare/app_common/executors/in_process_client_api_executor.py b/nvflare/app_common/executors/in_process_client_api_executor.py index 6f233b351c..b3c4ee556c 100644 --- a/nvflare/app_common/executors/in_process_client_api_executor.py +++ b/nvflare/app_common/executors/in_process_client_api_executor.py @@ -61,8 +61,12 @@ def __init__( evaluate_task_name: str = AppConstants.TASK_VALIDATION, submit_model_task_name: str = AppConstants.TASK_SUBMIT_MODEL, server_expected_format: str = ExchangeFormat.NUMPY, + memory_gc_rounds: int = 0, + cuda_empty_cache: bool = False, ): super(InProcessClientAPIExecutor, self).__init__() + self._memory_gc_rounds = memory_gc_rounds + self._cuda_empty_cache = cuda_empty_cache self._abort = False self._client_api = None self._result_pull_interval = result_pull_interval @@ -121,6 +125,11 @@ def handle_event(self, event_type: str, fl_ctx: FLContext): meta = self._prepare_task_meta(fl_ctx, None) self._client_api = InProcessClientAPI(task_metadata=meta, result_check_interval=self._result_pull_interval) self._client_api.init() + # Configure memory management if enabled + if self._memory_gc_rounds > 0: + self._client_api.configure_memory_management( + gc_rounds=self._memory_gc_rounds, cuda_empty_cache=self._cuda_empty_cache + ) self._data_bus.put_data(CLIENT_API_KEY, self._client_api) self._task_fn_thread.start() diff --git a/nvflare/app_common/np/recipes/cross_site_eval.py b/nvflare/app_common/np/recipes/cross_site_eval.py index 7259a3004d..d3a59c2c2a 100644 --- a/nvflare/app_common/np/recipes/cross_site_eval.py +++ b/nvflare/app_common/np/recipes/cross_site_eval.py @@ -40,6 +40,8 @@ class _CrossSiteEvalValidator(BaseModel): model_name: Optional[dict] = None submit_model_timeout: int = 600 validation_timeout: int = 6000 + client_memory_gc_rounds: int = 0 + cuda_empty_cache: bool = False @field_validator("initial_ckpt") @classmethod @@ -118,6 +120,8 @@ def __init__( model_name: Optional[dict] = None, submit_model_timeout: int = 600, validation_timeout: int = 6000, + client_memory_gc_rounds: int = 0, + cuda_empty_cache: bool = False, ): # Validate all inputs _CrossSiteEvalValidator( @@ -132,6 +136,8 @@ def __init__( model_name=model_name, submit_model_timeout=submit_model_timeout, validation_timeout=validation_timeout, + client_memory_gc_rounds=client_memory_gc_rounds, + cuda_empty_cache=cuda_empty_cache, ) job = FedJob(name=name, min_clients=min_clients) @@ -173,6 +179,8 @@ def __init__( launch_external_process=launch_external_process, command=command, framework=FrameworkType.RAW, + memory_gc_rounds=client_memory_gc_rounds, + cuda_empty_cache=cuda_empty_cache, ) job.to_clients(executor, tasks=[AppConstants.TASK_VALIDATION]) else: diff --git a/nvflare/app_common/np/recipes/fedavg.py b/nvflare/app_common/np/recipes/fedavg.py index d81a83b9aa..65cf9490e5 100644 --- a/nvflare/app_common/np/recipes/fedavg.py +++ b/nvflare/app_common/np/recipes/fedavg.py @@ -128,6 +128,8 @@ def __init__( save_filename: str = "FL_global_model.pt", exclude_vars: Optional[str] = None, aggregation_weights: Optional[Dict[str, float]] = None, + client_memory_gc_rounds: int = 0, + cuda_empty_cache: bool = False, ): # Store model and initial_ckpt for NumPy-specific setup (model wins over initial_model for 2.7 compat) self._np_model = model if model is not None else initial_model @@ -159,6 +161,8 @@ def __init__( save_filename=save_filename, exclude_vars=exclude_vars, aggregation_weights=aggregation_weights, + client_memory_gc_rounds=client_memory_gc_rounds, + cuda_empty_cache=cuda_empty_cache, ) # Override framework for cross-site evaluation compatibility diff --git a/nvflare/app_common/np/recipes/lr/fedavg.py b/nvflare/app_common/np/recipes/lr/fedavg.py index 73052bfe4a..a7b0932f92 100644 --- a/nvflare/app_common/np/recipes/lr/fedavg.py +++ b/nvflare/app_common/np/recipes/lr/fedavg.py @@ -37,6 +37,8 @@ class _FedAvgValidator(BaseModel): train_args: str launch_external_process: bool = False command: str + client_memory_gc_rounds: int = 0 + cuda_empty_cache: bool = False @field_validator("initial_ckpt") @classmethod @@ -97,6 +99,8 @@ def __init__( train_args: str = "", launch_external_process=False, command: str = "python3 -u", + client_memory_gc_rounds: int = 0, + cuda_empty_cache: bool = False, ): # Validate inputs internally v = _FedAvgValidator( @@ -110,6 +114,8 @@ def __init__( train_args=train_args, launch_external_process=launch_external_process, command=command, + client_memory_gc_rounds=client_memory_gc_rounds, + cuda_empty_cache=cuda_empty_cache, ) self.name = v.name @@ -122,6 +128,8 @@ def __init__( self.launch_external_process = v.launch_external_process self.command = v.command self.num_features = v.num_features + self.client_memory_gc_rounds = v.client_memory_gc_rounds + self.cuda_empty_cache = v.cuda_empty_cache # Create FedJob. job = FedJob(name=self.name, min_clients=self.min_clients) @@ -152,6 +160,8 @@ def __init__( framework=FrameworkType.RAW, server_expected_format=ExchangeFormat.RAW, params_transfer_type=TransferType.FULL, + memory_gc_rounds=self.client_memory_gc_rounds, + cuda_empty_cache=self.cuda_empty_cache, ) job.to_clients(runner) diff --git a/nvflare/app_opt/pt/client_api_launcher_executor.py b/nvflare/app_opt/pt/client_api_launcher_executor.py index eebd584874..3591adb9ae 100644 --- a/nvflare/app_opt/pt/client_api_launcher_executor.py +++ b/nvflare/app_opt/pt/client_api_launcher_executor.py @@ -49,6 +49,8 @@ def __init__( params_exchange_format: str = ExchangeFormat.PYTORCH, params_transfer_type: str = TransferType.FULL, config_file_name: str = CLIENT_API_CONFIG, + memory_gc_rounds: int = 0, + cuda_empty_cache: bool = False, ) -> None: ClientAPILauncherExecutor.__init__( self, @@ -74,6 +76,8 @@ def __init__( params_exchange_format=params_exchange_format, params_transfer_type=params_transfer_type, config_file_name=config_file_name, + memory_gc_rounds=memory_gc_rounds, + cuda_empty_cache=cuda_empty_cache, ) def initialize(self, fl_ctx: FLContext) -> None: diff --git a/nvflare/app_opt/pt/in_process_client_api_executor.py b/nvflare/app_opt/pt/in_process_client_api_executor.py index 285c5d7ace..93f1541899 100644 --- a/nvflare/app_opt/pt/in_process_client_api_executor.py +++ b/nvflare/app_opt/pt/in_process_client_api_executor.py @@ -38,6 +38,8 @@ def __init__( submit_model_task_name: str = AppConstants.TASK_SUBMIT_MODEL, params_exchange_format=ExchangeFormat.PYTORCH, server_expected_format=ExchangeFormat.NUMPY, + memory_gc_rounds: int = 0, + cuda_empty_cache: bool = False, ): super(PTInProcessClientAPIExecutor, self).__init__( task_script_path=task_script_path, @@ -54,6 +56,8 @@ def __init__( params_transfer_type=params_transfer_type, log_pull_interval=log_pull_interval, server_expected_format=server_expected_format, + memory_gc_rounds=memory_gc_rounds, + cuda_empty_cache=cuda_empty_cache, ) fobs.register(TensorDecomposer) if ( diff --git a/nvflare/app_opt/pt/recipes/cyclic.py b/nvflare/app_opt/pt/recipes/cyclic.py index 6145da312f..a7da973173 100644 --- a/nvflare/app_opt/pt/recipes/cyclic.py +++ b/nvflare/app_opt/pt/recipes/cyclic.py @@ -62,6 +62,8 @@ def __init__( server_expected_format: ExchangeFormat = ExchangeFormat.NUMPY, params_transfer_type: TransferType = TransferType.FULL, server_memory_gc_rounds: int = 1, + client_memory_gc_rounds: int = 0, + cuda_empty_cache: bool = False, ): # Validate initial_ckpt early (base class won't see it since we pass None) from nvflare.recipe.utils import validate_ckpt @@ -93,6 +95,8 @@ def __init__( server_expected_format=server_expected_format, params_transfer_type=params_transfer_type, server_memory_gc_rounds=server_memory_gc_rounds, + client_memory_gc_rounds=client_memory_gc_rounds, + cuda_empty_cache=cuda_empty_cache, ) def _setup_model_and_persistor(self, job) -> str: diff --git a/nvflare/app_opt/pt/recipes/fedavg.py b/nvflare/app_opt/pt/recipes/fedavg.py index dca486cab8..58d5324c33 100644 --- a/nvflare/app_opt/pt/recipes/fedavg.py +++ b/nvflare/app_opt/pt/recipes/fedavg.py @@ -126,6 +126,8 @@ def __init__( aggregation_weights: Optional[dict[str, float]] = None, server_memory_gc_rounds: int = 0, enable_tensor_disk_offload: bool = False, + client_memory_gc_rounds: int = 0, + cuda_empty_cache: bool = False, ): # Store PyTorch-specific model_locator before calling parent self._pt_model_locator = model_locator @@ -158,6 +160,8 @@ def __init__( aggregation_weights=aggregation_weights, server_memory_gc_rounds=server_memory_gc_rounds, enable_tensor_disk_offload=enable_tensor_disk_offload, + client_memory_gc_rounds=client_memory_gc_rounds, + cuda_empty_cache=cuda_empty_cache, ) def _setup_model_and_persistor(self, job) -> str: diff --git a/nvflare/app_opt/pt/recipes/fedavg_he.py b/nvflare/app_opt/pt/recipes/fedavg_he.py index 44b659693d..d9ac085096 100644 --- a/nvflare/app_opt/pt/recipes/fedavg_he.py +++ b/nvflare/app_opt/pt/recipes/fedavg_he.py @@ -59,6 +59,8 @@ class _FedAvgRecipeWithHEValidator(BaseModel): params_transfer_type: TransferType = TransferType.FULL encrypt_layers: Optional[Union[List[str], str]] = None server_memory_gc_rounds: int = 1 + client_memory_gc_rounds: int = 0 + cuda_empty_cache: bool = False class FedAvgRecipeWithHE(Recipe): @@ -169,6 +171,8 @@ def __init__( params_transfer_type: TransferType = TransferType.FULL, encrypt_layers: Optional[Union[List[str], str]] = None, server_memory_gc_rounds: int = 1, + client_memory_gc_rounds: int = 0, + cuda_empty_cache: bool = False, ): # Validate inputs internally v = _FedAvgRecipeWithHEValidator( @@ -187,6 +191,8 @@ def __init__( params_transfer_type=params_transfer_type, encrypt_layers=encrypt_layers, server_memory_gc_rounds=server_memory_gc_rounds, + client_memory_gc_rounds=client_memory_gc_rounds, + cuda_empty_cache=cuda_empty_cache, ) self.name = v.name @@ -212,6 +218,8 @@ def __init__( self.params_transfer_type: TransferType = v.params_transfer_type self.encrypt_layers: Optional[Union[List[str], str]] = v.encrypt_layers self.server_memory_gc_rounds = v.server_memory_gc_rounds + self.client_memory_gc_rounds = v.client_memory_gc_rounds + self.cuda_empty_cache = v.cuda_empty_cache # Create BaseFedJob without model first (model setup done manually below for HE) job = BaseFedJob( @@ -278,6 +286,8 @@ def __init__( framework=FrameworkType.PYTORCH, server_expected_format=self.server_expected_format, params_transfer_type=self.params_transfer_type, + memory_gc_rounds=self.client_memory_gc_rounds, + cuda_empty_cache=self.cuda_empty_cache, ) job.to_clients(executor) diff --git a/nvflare/app_opt/pt/recipes/fedeval.py b/nvflare/app_opt/pt/recipes/fedeval.py index ecfbf23b5e..d2b2564d1d 100644 --- a/nvflare/app_opt/pt/recipes/fedeval.py +++ b/nvflare/app_opt/pt/recipes/fedeval.py @@ -130,6 +130,8 @@ def __init__( server_expected_format: ExchangeFormat = ExchangeFormat.NUMPY, validation_timeout: int = 6000, per_site_config: Optional[Dict[str, Dict]] = None, + client_memory_gc_rounds: int = 0, + cuda_empty_cache: bool = False, ): # Validate eval_ckpt _FedEvalValidator(eval_ckpt=eval_ckpt) @@ -145,6 +147,8 @@ def __init__( self.server_expected_format = server_expected_format self.validation_timeout = validation_timeout self.per_site_config = per_site_config + self.client_memory_gc_rounds = client_memory_gc_rounds + self.cuda_empty_cache = cuda_empty_cache # Create BaseFedJob job = BaseFedJob( @@ -197,6 +201,8 @@ def __init__( command=cmd, framework=FrameworkType.PYTORCH, server_expected_format=expected_format, + memory_gc_rounds=self.client_memory_gc_rounds, + cuda_empty_cache=self.cuda_empty_cache, ) job.to(executor, site_name) else: @@ -207,6 +213,8 @@ def __init__( command=self.command, framework=FrameworkType.PYTORCH, server_expected_format=self.server_expected_format, + memory_gc_rounds=self.client_memory_gc_rounds, + cuda_empty_cache=self.cuda_empty_cache, ) job.to_clients(executor) diff --git a/nvflare/app_opt/pt/recipes/fedopt.py b/nvflare/app_opt/pt/recipes/fedopt.py index 44d05dc83a..12b1ed489a 100644 --- a/nvflare/app_opt/pt/recipes/fedopt.py +++ b/nvflare/app_opt/pt/recipes/fedopt.py @@ -47,6 +47,8 @@ class _FedOptValidator(BaseModel): server_expected_format: ExchangeFormat = ExchangeFormat.NUMPY device: Optional[str] = None server_memory_gc_rounds: int = 1 + client_memory_gc_rounds: int = 0 + cuda_empty_cache: bool = False class FedOptRecipe(Recipe): @@ -135,6 +137,8 @@ def __init__( optimizer_args: Optional[dict] = None, lr_scheduler_args: Optional[dict] = None, server_memory_gc_rounds: int = 1, + client_memory_gc_rounds: int = 0, + cuda_empty_cache: bool = False, ): # Validate inputs internally v = _FedOptValidator( @@ -151,6 +155,8 @@ def __init__( server_expected_format=server_expected_format, device=device, server_memory_gc_rounds=server_memory_gc_rounds, + client_memory_gc_rounds=client_memory_gc_rounds, + cuda_empty_cache=cuda_empty_cache, ) self.name = v.name @@ -179,6 +185,8 @@ def __init__( self.optimizer_args = ensure_config_type_dict(optimizer_args) self.lr_scheduler_args = ensure_config_type_dict(lr_scheduler_args) self.server_memory_gc_rounds = v.server_memory_gc_rounds + self.client_memory_gc_rounds = v.client_memory_gc_rounds + self.cuda_empty_cache = v.cuda_empty_cache # Replace {num_rounds} placeholder if present in lr_scheduler_args processed_lr_scheduler_args = None @@ -276,6 +284,8 @@ def __init__( framework=FrameworkType.PYTORCH, server_expected_format=self.server_expected_format, params_transfer_type=TransferType.DIFF, + memory_gc_rounds=self.client_memory_gc_rounds, + cuda_empty_cache=self.cuda_empty_cache, ) job.to_clients(executor) diff --git a/nvflare/app_opt/pt/recipes/scaffold.py b/nvflare/app_opt/pt/recipes/scaffold.py index fe4c10efdd..81f2592ee8 100644 --- a/nvflare/app_opt/pt/recipes/scaffold.py +++ b/nvflare/app_opt/pt/recipes/scaffold.py @@ -41,6 +41,8 @@ class _ScaffoldValidator(BaseModel): server_expected_format: ExchangeFormat = ExchangeFormat.NUMPY params_transfer_type: TransferType = TransferType.FULL server_memory_gc_rounds: int = 0 + client_memory_gc_rounds: int = 0 + cuda_empty_cache: bool = False class ScaffoldRecipe(Recipe): @@ -102,6 +104,8 @@ def __init__( server_expected_format: ExchangeFormat = ExchangeFormat.NUMPY, params_transfer_type: TransferType = TransferType.FULL, server_memory_gc_rounds: int = 0, + client_memory_gc_rounds: int = 0, + cuda_empty_cache: bool = False, ): # Validate inputs internally v = _ScaffoldValidator( @@ -117,6 +121,8 @@ def __init__( server_expected_format=server_expected_format, params_transfer_type=params_transfer_type, server_memory_gc_rounds=server_memory_gc_rounds, + client_memory_gc_rounds=client_memory_gc_rounds, + cuda_empty_cache=cuda_empty_cache, ) self.name = v.name @@ -139,6 +145,8 @@ def __init__( self.server_expected_format: ExchangeFormat = v.server_expected_format self.params_transfer_type: TransferType = v.params_transfer_type self.server_memory_gc_rounds = v.server_memory_gc_rounds + self.client_memory_gc_rounds = v.client_memory_gc_rounds + self.cuda_empty_cache = v.cuda_empty_cache # Create BaseFedJob job = BaseFedJob( @@ -176,6 +184,8 @@ def __init__( framework=FrameworkType.PYTORCH, server_expected_format=self.server_expected_format, params_transfer_type=self.params_transfer_type, + memory_gc_rounds=self.client_memory_gc_rounds, + cuda_empty_cache=self.cuda_empty_cache, ) job.to_clients(executor) diff --git a/nvflare/app_opt/pt/recipes/swarm.py b/nvflare/app_opt/pt/recipes/swarm.py index 4841c4a6d7..1bee35b57f 100644 --- a/nvflare/app_opt/pt/recipes/swarm.py +++ b/nvflare/app_opt/pt/recipes/swarm.py @@ -116,6 +116,10 @@ class SimpleSwarmLearningRecipe(BaseSwarmLearningRecipe): train_args: Additional arguments for the training script. do_cross_site_eval: Whether to perform cross-site evaluation. cross_site_eval_timeout: Timeout for cross-site evaluation. + memory_gc_rounds: Run gc.collect() + malloc_trim every N FL rounds on both the trainer + and aggregator roles. Defaults to 1 (every round) to match legacy behavior where + gc.collect() was called unconditionally after each trainer submission. Set to 0 to disable. + cuda_empty_cache: Call torch.cuda.empty_cache() during cleanup. Defaults to False. Example: Using nn.Module instance: @@ -151,6 +155,8 @@ def __init__( train_args: dict = None, do_cross_site_eval: bool = False, cross_site_eval_timeout: float = 300, + memory_gc_rounds: int = 1, + cuda_empty_cache: bool = False, ): _SwarmValidator(initial_ckpt=initial_ckpt) @@ -173,7 +179,14 @@ def __init__( train_args = {} else: # Validate train_args doesn't conflict with ScriptRunner reserved parameters - reserved_keys = {"script", "launch_external_process", "command", "framework"} + reserved_keys = { + "script", + "launch_external_process", + "command", + "framework", + "memory_gc_rounds", + "cuda_empty_cache", + } conflicts = set(train_args.keys()) & reserved_keys if conflicts: raise ValueError(f"train_args contains reserved keys that conflict with ScriptRunner: {conflicts}") @@ -186,10 +199,17 @@ def __init__( server_config = SwarmServerConfig(num_rounds=num_rounds) client_config = SwarmClientConfig( - executor=ScriptRunner(script=train_script, **train_args), + executor=ScriptRunner( + script=train_script, + memory_gc_rounds=memory_gc_rounds, + cuda_empty_cache=cuda_empty_cache, + **train_args, + ), aggregator=aggregator, persistor=PTFileModelPersistor(model=model_instance, source_ckpt_file_full_name=ckpt_path), shareable_generator=SimpleModelShareableGenerator(), + memory_gc_rounds=memory_gc_rounds, + cuda_empty_cache=cuda_empty_cache, ) BaseSwarmLearningRecipe.__init__(self, name, server_config, client_config, cse_config, job=job) diff --git a/nvflare/app_opt/tf/client_api_launcher_executor.py b/nvflare/app_opt/tf/client_api_launcher_executor.py index 33a20aac50..a3c06f9df5 100644 --- a/nvflare/app_opt/tf/client_api_launcher_executor.py +++ b/nvflare/app_opt/tf/client_api_launcher_executor.py @@ -47,6 +47,8 @@ def __init__( params_exchange_format: str = ExchangeFormat.KERAS_LAYER_WEIGHTS, params_transfer_type: str = TransferType.FULL, config_file_name: str = CLIENT_API_CONFIG, + memory_gc_rounds: int = 0, + cuda_empty_cache: bool = False, ) -> None: ClientAPILauncherExecutor.__init__( self, @@ -72,6 +74,8 @@ def __init__( params_exchange_format=params_exchange_format, params_transfer_type=params_transfer_type, config_file_name=config_file_name, + memory_gc_rounds=memory_gc_rounds, + cuda_empty_cache=cuda_empty_cache, ) def initialize(self, fl_ctx: FLContext) -> None: diff --git a/nvflare/app_opt/tf/in_process_client_api_executor.py b/nvflare/app_opt/tf/in_process_client_api_executor.py index 3beb6af775..a0ae17bc21 100644 --- a/nvflare/app_opt/tf/in_process_client_api_executor.py +++ b/nvflare/app_opt/tf/in_process_client_api_executor.py @@ -37,6 +37,8 @@ def __init__( submit_model_task_name: str = AppConstants.TASK_SUBMIT_MODEL, params_exchange_format=ExchangeFormat.KERAS_LAYER_WEIGHTS, server_expected_format=ExchangeFormat.NUMPY, + memory_gc_rounds: int = 0, + cuda_empty_cache: bool = False, ): super(TFInProcessClientAPIExecutor, self).__init__( task_script_path=task_script_path, @@ -53,6 +55,8 @@ def __init__( params_transfer_type=params_transfer_type, log_pull_interval=log_pull_interval, server_expected_format=server_expected_format, + memory_gc_rounds=memory_gc_rounds, + cuda_empty_cache=cuda_empty_cache, ) if ( diff --git a/nvflare/app_opt/tf/recipes/cyclic.py b/nvflare/app_opt/tf/recipes/cyclic.py index bc531295fa..7aa598a9f5 100644 --- a/nvflare/app_opt/tf/recipes/cyclic.py +++ b/nvflare/app_opt/tf/recipes/cyclic.py @@ -62,6 +62,7 @@ def __init__( server_expected_format: ExchangeFormat = ExchangeFormat.NUMPY, params_transfer_type: TransferType = TransferType.FULL, server_memory_gc_rounds: int = 1, + client_memory_gc_rounds: int = 0, ): # Validate initial_ckpt early (base class won't see it since we pass None) from nvflare.recipe.utils import validate_ckpt @@ -93,6 +94,8 @@ def __init__( server_expected_format=server_expected_format, params_transfer_type=params_transfer_type, server_memory_gc_rounds=server_memory_gc_rounds, + client_memory_gc_rounds=client_memory_gc_rounds, + cuda_empty_cache=False, ) def _setup_model_and_persistor(self, job) -> str: diff --git a/nvflare/app_opt/tf/recipes/fedavg.py b/nvflare/app_opt/tf/recipes/fedavg.py index 4b632d7f72..abba109c6a 100644 --- a/nvflare/app_opt/tf/recipes/fedavg.py +++ b/nvflare/app_opt/tf/recipes/fedavg.py @@ -118,6 +118,7 @@ def __init__( shutdown_timeout: float = 0.0, key_metric: str = "accuracy", server_memory_gc_rounds: int = 0, + client_memory_gc_rounds: int = 0, ): # Call the unified FedAvgRecipe with TensorFlow-specific settings super().__init__( @@ -141,6 +142,8 @@ def __init__( shutdown_timeout=shutdown_timeout, key_metric=key_metric, server_memory_gc_rounds=server_memory_gc_rounds, + client_memory_gc_rounds=client_memory_gc_rounds, + cuda_empty_cache=False, ) def _setup_model_and_persistor(self, job) -> str: diff --git a/nvflare/app_opt/tf/recipes/fedopt.py b/nvflare/app_opt/tf/recipes/fedopt.py index effd90cbb7..46f25b6d10 100644 --- a/nvflare/app_opt/tf/recipes/fedopt.py +++ b/nvflare/app_opt/tf/recipes/fedopt.py @@ -42,6 +42,7 @@ class _FedOptValidator(BaseModel): optimizer_args: Optional[dict] = None lr_scheduler_args: Optional[dict] = None server_memory_gc_rounds: int = 0 + client_memory_gc_rounds: int = 0 class FedOptRecipe(Recipe): @@ -130,6 +131,7 @@ def __init__( optimizer_args: Optional[dict] = None, lr_scheduler_args: Optional[dict] = None, server_memory_gc_rounds: int = 0, + client_memory_gc_rounds: int = 0, ): # Validate inputs internally v = _FedOptValidator( @@ -147,6 +149,7 @@ def __init__( optimizer_args=optimizer_args, lr_scheduler_args=lr_scheduler_args, server_memory_gc_rounds=server_memory_gc_rounds, + client_memory_gc_rounds=client_memory_gc_rounds, ) self.name = v.name @@ -171,6 +174,7 @@ def __init__( self.optimizer_args = ensure_config_type_dict(v.optimizer_args) self.lr_scheduler_args = ensure_config_type_dict(v.lr_scheduler_args) self.server_memory_gc_rounds = v.server_memory_gc_rounds + self.client_memory_gc_rounds = v.client_memory_gc_rounds # Create BaseFedJob job = BaseFedJob( @@ -207,6 +211,8 @@ def __init__( framework=FrameworkType.TENSORFLOW, server_expected_format=self.server_expected_format, params_transfer_type=self.params_transfer_type, + memory_gc_rounds=self.client_memory_gc_rounds, + cuda_empty_cache=False, ) job.to_clients(executor) diff --git a/nvflare/app_opt/tf/recipes/scaffold.py b/nvflare/app_opt/tf/recipes/scaffold.py index 0365df0b39..44c3f51b85 100644 --- a/nvflare/app_opt/tf/recipes/scaffold.py +++ b/nvflare/app_opt/tf/recipes/scaffold.py @@ -40,6 +40,7 @@ class _ScaffoldValidator(BaseModel): server_expected_format: ExchangeFormat = ExchangeFormat.NUMPY params_transfer_type: TransferType = TransferType.FULL server_memory_gc_rounds: int = 0 + client_memory_gc_rounds: int = 0 class ScaffoldRecipe(Recipe): @@ -125,6 +126,7 @@ def __init__( server_expected_format: ExchangeFormat = ExchangeFormat.NUMPY, params_transfer_type: TransferType = TransferType.FULL, server_memory_gc_rounds: int = 0, + client_memory_gc_rounds: int = 0, ): # Validate inputs internally v = _ScaffoldValidator( @@ -140,6 +142,7 @@ def __init__( server_expected_format=server_expected_format, params_transfer_type=params_transfer_type, server_memory_gc_rounds=server_memory_gc_rounds, + client_memory_gc_rounds=client_memory_gc_rounds, ) self.name = v.name @@ -162,6 +165,7 @@ def __init__( self.server_expected_format: ExchangeFormat = v.server_expected_format self.params_transfer_type: TransferType = v.params_transfer_type self.server_memory_gc_rounds = v.server_memory_gc_rounds + self.client_memory_gc_rounds = v.client_memory_gc_rounds # Create BaseFedJob with initial model job = BaseFedJob( @@ -189,6 +193,8 @@ def __init__( framework=FrameworkType.TENSORFLOW, server_expected_format=self.server_expected_format, params_transfer_type=self.params_transfer_type, + memory_gc_rounds=self.client_memory_gc_rounds, + cuda_empty_cache=False, ) job.to_clients(executor) diff --git a/nvflare/client/api_spec.py b/nvflare/client/api_spec.py index 458976638e..0e4146d762 100644 --- a/nvflare/client/api_spec.py +++ b/nvflare/client/api_spec.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging from abc import ABC, abstractmethod from typing import Any, Dict, Optional @@ -23,6 +24,35 @@ class APISpec(ABC): + """Abstract base class for NVFlare Client APIs. + + Provides shared memory management functionality for subclasses. + """ + + def __init__(self): + """Initialize memory management attributes.""" + self._memory_gc_rounds = 0 + self._cuda_empty_cache = False + self._round_count = 0 + self._memory_logger = logging.getLogger(self.__class__.__name__) + + def _maybe_cleanup_memory(self): + """Perform memory cleanup if configured (every N rounds). + + This method is called after send() to periodically clean up memory. + Cleanup is only performed if _memory_gc_rounds > 0 and the current + round count is a multiple of _memory_gc_rounds. + """ + if self._memory_gc_rounds <= 0: + return + + self._round_count += 1 + if self._round_count % self._memory_gc_rounds == 0: + from nvflare.fuel.utils.memory_utils import cleanup_memory + + cleanup_memory(cuda_empty_cache=self._cuda_empty_cache) + self._memory_logger.debug(f"Memory cleanup performed at round {self._round_count}") + @abstractmethod def init(self, rank: Optional[str] = None): """Initializes NVFlare Client API environment. diff --git a/nvflare/client/config.py b/nvflare/client/config.py index 63efb24fb5..a14a6c1054 100644 --- a/nvflare/client/config.py +++ b/nvflare/client/config.py @@ -48,6 +48,8 @@ class ConfigKey: TASK_EXCHANGE = "TASK_EXCHANGE" METRICS_EXCHANGE = "METRICS_EXCHANGE" HEARTBEAT_TIMEOUT = "HEARTBEAT_TIMEOUT" + MEMORY_GC_ROUNDS = "memory_gc_rounds" + CUDA_EMPTY_CACHE = "cuda_empty_cache" class ClientConfig: diff --git a/nvflare/client/ex_process/api.py b/nvflare/client/ex_process/api.py index 95f0572b5e..edff7d1af4 100644 --- a/nvflare/client/ex_process/api.py +++ b/nvflare/client/ex_process/api.py @@ -83,11 +83,14 @@ def _register_tensor_decomposer(): class ExProcessClientAPI(APISpec): def __init__(self, config_file: str): + super().__init__() # Initialize memory management from base class + self.model_registry = None self.logger = get_obj_logger(self) self.receive_called = False self.config_file = config_file self.flare_agent = None + # Memory settings will be read from config in init() def get_model_registry(self) -> ModelRegistry: """Gets the ModelRegistry.""" @@ -140,6 +143,20 @@ def init(self, rank: Optional[str] = None): self.model_registry = ModelRegistry(client_config, rank, flare_agent) self.flare_agent = flare_agent + + # Read memory management settings from config (with env var override) + task_exchange = client_config.config.get(ConfigKey.TASK_EXCHANGE, {}) + config_gc_rounds = task_exchange.get(ConfigKey.MEMORY_GC_ROUNDS, 0) + config_cuda_cache = task_exchange.get(ConfigKey.CUDA_EMPTY_CACHE, False) + + # Environment variables override config values. + self._memory_gc_rounds = int(os.environ.get("NVFLARE_CLIENT_MEMORY_GC_ROUNDS", str(config_gc_rounds))) + self._cuda_empty_cache = ( + os.environ.get("NVFLARE_CUDA_EMPTY_CACHE", str(config_cuda_cache)).lower() == "true" + ) + + if self._memory_gc_rounds > 0: + self.logger.info(f"Memory management enabled: cleanup every {self._memory_gc_rounds} round(s)") except Exception as e: self.logger.error(f"flare.init failed: {e}") raise e @@ -161,6 +178,9 @@ def send(self, model: FLModel, clear_cache: bool = True) -> None: if clear_cache: self.clear() + # Perform memory cleanup if configured + self._maybe_cleanup_memory() + def system_info(self) -> Dict: model_registry = self.get_model_registry() return model_registry.get_sys_info() diff --git a/nvflare/client/in_process/api.py b/nvflare/client/in_process/api.py index 5993b61af9..9cef28f75d 100644 --- a/nvflare/client/in_process/api.py +++ b/nvflare/client/in_process/api.py @@ -44,6 +44,8 @@ def __init__(self, task_metadata: dict, result_check_interval: float = 2.0): task_metadata (dict): task metadata, added to client_config. result_check_interval (float): how often to check if result is available. """ + super().__init__() # Initialize memory management from base class + self.data_bus = DataBus() self.data_bus.subscribe([TOPIC_GLOBAL_RESULT], self.__receive_callback) self.data_bus.subscribe([TOPIC_ABORT, TOPIC_STOP], self.__ask_to_abort) @@ -98,6 +100,18 @@ def prepare_client_config(self, config): def set_meta(self, meta: dict): self.meta = meta + def configure_memory_management(self, gc_rounds: int = 0, cuda_empty_cache: bool = False): + """Configure memory management settings. + + Args: + gc_rounds: Cleanup every N rounds. 0 = disabled. + cuda_empty_cache: If True, call torch.cuda.empty_cache() on cleanup. + """ + self._memory_gc_rounds = gc_rounds + self._cuda_empty_cache = cuda_empty_cache + if gc_rounds > 0: + self.logger.info(f"Memory management enabled: cleanup every {gc_rounds} round(s)") + def receive(self, timeout: Optional[float] = None) -> Optional[FLModel]: result = self.__receive() self.receive_called = True @@ -139,6 +153,9 @@ def send(self, model: FLModel, clear_cache: bool = True) -> None: self.fl_model = None self.receive_called = False + # Perform memory cleanup if configured + self._maybe_cleanup_memory() + def system_info(self) -> Dict: return self.sys_info diff --git a/nvflare/fuel/utils/memory_utils.py b/nvflare/fuel/utils/memory_utils.py index 10f1d132c5..187a2db805 100644 --- a/nvflare/fuel/utils/memory_utils.py +++ b/nvflare/fuel/utils/memory_utils.py @@ -15,19 +15,28 @@ """Memory management utilities for federated learning. This module provides memory cleanup utilities to help manage RSS (Resident Set Size) -in long-running FL jobs using Python + PyTorch + glibc. +in long-running FL jobs using Python + PyTorch + glibc/jemalloc. + +Allocator Support: +- glibc: Uses malloc_trim() to return freed pages to OS +- jemalloc: Relies on auto-decay (MALLOC_CONF), no manual action needed Best Practices: - Client: Set MALLOC_ARENA_MAX=2, cleanup every round - Server: Set MALLOC_ARENA_MAX=4, cleanup every 5 rounds +- jemalloc: Set MALLOC_CONF="dirty_decay_ms:5000,muzzy_decay_ms:5000" Usage: - from nvflare.fuel.utils.memory_utils import cleanup_memory + from nvflare.fuel.utils.memory_utils import cleanup_memory, get_allocator_type + + # Check which allocator is in use + allocator = get_allocator_type() # "glibc", "jemalloc", or "unknown" # At end of each round (client) or every N rounds (server) - cleanup_memory(torch_cuda_empty_cache=True) # True for PyTorch GPU clients + cleanup_memory(cuda_empty_cache=True) # True for PyTorch GPU clients """ +import ctypes import gc import logging from ctypes import CDLL, c_size_t @@ -56,6 +65,37 @@ def _get_glibc() -> Optional[CDLL]: return None +@lru_cache(maxsize=1) +def get_allocator_type() -> str: + """Detect which memory allocator is in use at runtime. + + Returns: + "jemalloc": jemalloc is loaded (recommended for PyTorch) + "glibc": Standard glibc malloc is in use + "unknown": Could not detect allocator type + + Note: + - jemalloc is typically loaded via LD_PRELOAD + - Detection is cached after first call + - Safe to call frequently (no overhead after first call) + """ + try: + # Load the C library that the process is using + libc = ctypes.CDLL(None) + + # jemalloc has mallctl function + if hasattr(libc, "mallctl"): + return "jemalloc" + + # glibc has malloc_trim + if hasattr(libc, "malloc_trim"): + return "glibc" + except Exception: + pass + + return "unknown" + + def try_malloc_trim() -> Optional[int]: """Attempt to release free memory back to the OS (glibc only). @@ -80,31 +120,42 @@ def try_malloc_trim() -> Optional[int]: return None -def cleanup_memory(torch_cuda_empty_cache: bool = False) -> None: - """Perform memory cleanup to reduce RSS. +def cleanup_memory(cuda_empty_cache: bool = False) -> None: + """Perform allocator-aware memory cleanup to reduce RSS. This function: 1. Runs Python garbage collection (gc.collect) - 2. Releases free heap pages to OS (malloc_trim, Linux/glibc only) + 2. For glibc: Releases free heap pages to OS (malloc_trim) + For jemalloc: Relies on auto-decay (no manual action needed) 3. Optionally clears PyTorch CUDA cache Args: - torch_cuda_empty_cache: If True, also call torch.cuda.empty_cache(). + cuda_empty_cache: If True, also call torch.cuda.empty_cache(). Only applicable to PyTorch GPU clients. Note: Call this at the end of each FL round (client) or every N rounds (server). + The function automatically detects the allocator type and applies + the appropriate cleanup strategy. """ - # Step 1: Python garbage collection + # Step 1: Python garbage collection (always) gc.collect() - # Step 2: Return free heap pages to OS (glibc only) - result = try_malloc_trim() - if result is not None: - logger.debug(f"malloc_trim returned {result}") + # Step 2: Allocator-specific cleanup + allocator = get_allocator_type() + if allocator == "glibc": + # glibc: manually return freed pages to OS + result = try_malloc_trim() + if result is not None: + logger.debug(f"malloc_trim returned {result}") + elif allocator == "jemalloc": + # jemalloc: auto-decay handles memory return, no manual action needed + # Memory is returned based on MALLOC_CONF settings (dirty_decay_ms, muzzy_decay_ms) + logger.debug("jemalloc detected, relying on auto-decay for memory management") + # unknown: gc.collect() is the only safe action # Step 3: Clear PyTorch CUDA cache if requested - if torch_cuda_empty_cache: + if cuda_empty_cache: try: import torch diff --git a/nvflare/job_config/script_runner.py b/nvflare/job_config/script_runner.py index deadcd0566..455c63049e 100644 --- a/nvflare/job_config/script_runner.py +++ b/nvflare/job_config/script_runner.py @@ -62,6 +62,8 @@ def __init__( pipe_connect_type: str = None, launch_once: bool = True, shutdown_timeout: float = 0.0, + memory_gc_rounds: int = 0, + cuda_empty_cache: bool = False, ): """BaseScriptRunner is used with FedJob API to run or launch a script. @@ -172,6 +174,8 @@ def __init__( self._task_pipe = task_pipe self._executor = executor self._launcher = launcher + self._memory_gc_rounds = memory_gc_rounds + self._cuda_empty_cache = cuda_empty_cache def _create_cell_pipe(self): ct = self._pipe_connect_type @@ -230,6 +234,8 @@ def add_to_fed_job(self, job: FedJob, ctx, **kwargs): params_exchange_format=self._params_exchange_format, params_transfer_type=self._params_transfer_type, server_expected_format=self._server_expected_format, + memory_gc_rounds=self._memory_gc_rounds, + cuda_empty_cache=self._cuda_empty_cache, ) ) job.add_executor(executor, tasks=tasks, ctx=ctx) @@ -264,6 +270,8 @@ def add_to_fed_job(self, job: FedJob, ctx, **kwargs): params_exchange_format=self._params_exchange_format, params_transfer_type=self._params_transfer_type, server_expected_format=self._server_expected_format, + memory_gc_rounds=self._memory_gc_rounds, + cuda_empty_cache=self._cuda_empty_cache, ) ) job.add_executor(executor, tasks=tasks, ctx=ctx) @@ -309,6 +317,8 @@ def __init__( pipe_connect_type: PipeConnectType = PipeConnectType.VIA_CP, launch_once: bool = True, shutdown_timeout: float = 0.0, + memory_gc_rounds: int = 0, + cuda_empty_cache: bool = False, ): """ScriptRunner is used with FedJob API to run or launch a script. @@ -341,4 +351,6 @@ def __init__( pipe_connect_type=pipe_connect_type, launch_once=launch_once, shutdown_timeout=shutdown_timeout, + memory_gc_rounds=memory_gc_rounds, + cuda_empty_cache=cuda_empty_cache, ) diff --git a/nvflare/lighter/templates/master_template.yml b/nvflare/lighter/templates/master_template.yml index 59c8b33f4f..077a0cd73c 100644 --- a/nvflare/lighter/templates/master_template.yml +++ b/nvflare/lighter/templates/master_template.yml @@ -8,7 +8,7 @@ readme_am: | client.crt client.key fl_admin.sh - + Please install the nvflare package by 'python3 -m pip nvflare.' This will install a set of Python codes in your environment. After installation, you can run the fl_admin.sh file to start communicating to the admin server. @@ -538,7 +538,7 @@ start_svr_sh: | else $DIR/sub_start.sh & fi - + stop_fl_sh: | #!/usr/bin/env bash DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" @@ -566,8 +566,24 @@ sub_start_sh: | mkdir -p $DIR/../transfer export PYTHONPATH=/local/custom:$PYTHONPATH echo "PYTHONPATH is $PYTHONPATH" - - # Memory management: limit glibc memory arenas to reduce RSS fragmentation + + # Memory management configuration + # Opt-in jemalloc preload for PyTorch workloads. + # Enable by setting NVFLARE_ENABLE_JEMALLOC_PRELOAD=true. + if [ "${NVFLARE_ENABLE_JEMALLOC_PRELOAD:-false}" = "true" ]; then + for JEMALLOC in /usr/lib/x86_64-linux-gnu/libjemalloc.so.2 \ + /usr/lib64/libjemalloc.so.2 \ + /usr/local/lib/libjemalloc.so; do + if [ -f "$JEMALLOC" ]; then + export LD_PRELOAD="${LD_PRELOAD:+$LD_PRELOAD:}$JEMALLOC" + export MALLOC_CONF="${MALLOC_CONF:-dirty_decay_ms:5000,muzzy_decay_ms:5000}" + echo "Using jemalloc: $JEMALLOC" + break + fi + done + fi + + # Limit glibc memory arenas to reduce RSS fragmentation (ignored by jemalloc) # Recommended: MALLOC_ARENA_MAX=2 for clients, MALLOC_ARENA_MAX=4 for servers export MALLOC_ARENA_MAX=${MALLOC_ARENA_MAX:-4} @@ -588,11 +604,11 @@ sub_start_sh: | SECONDS=0 lst=-400 restart_count=0 - + start_python() { python3 -u -m nvflare.private.fed.app.{~~type~~}.{~~app_name~~} -m $DIR/.. -s fed_{~~type~~}.json --set secure_train=true {~~cln_uid~~} org={~~org_name~~} config_folder={~~config_folder~~} } - + start_fl() { if [[ $(( $SECONDS - $lst )) -lt 300 ]]; then ((restart_count++)) @@ -625,7 +641,7 @@ sub_start_sh: | kill -9 $pid rm -f $DIR/../pid.fl $DIR/../shutdown.fl $DIR/../restart.fl 2> /dev/null 1>&2 } - + if [[ "$doVerify" == "true" ]]; then python3 -m nvflare.tool.verify_startup_kits -f $DIR/../ -c $DIR/rootCA.pem verification_status=$? @@ -699,10 +715,10 @@ docker_cln_sh: | # export MY_DATA_DIR=$SOME_DIRECTORY # before running docker.sh - # for all gpus use line below + # for all gpus use line below #GPU2USE='--gpus=all' # for 2 gpus use line below - #GPU2USE='--gpus=2' + #GPU2USE='--gpus=2' # for specific gpus as gpu#0 and gpu#2 use line below #GPU2USE='--gpus="device=0,2"' # to use host network, use line below @@ -956,7 +972,7 @@ helm_chart_values: | cloud_script_header: | #!/usr/bin/env bash - + DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" function report_status() { status="$1" diff --git a/nvflare/recipe/cyclic.py b/nvflare/recipe/cyclic.py index e2406c38fb..44bc30a658 100644 --- a/nvflare/recipe/cyclic.py +++ b/nvflare/recipe/cyclic.py @@ -42,6 +42,8 @@ class _CyclicValidator(BaseModel): params_transfer_type: TransferType = TransferType.FULL framework: FrameworkType = FrameworkType.NUMPY server_memory_gc_rounds: int = 1 + client_memory_gc_rounds: int = 0 + cuda_empty_cache: bool = False class CyclicRecipe(Recipe): @@ -112,6 +114,8 @@ def __init__( server_expected_format: ExchangeFormat = ExchangeFormat.NUMPY, params_transfer_type: TransferType = TransferType.FULL, server_memory_gc_rounds: int = 1, + client_memory_gc_rounds: int = 0, + cuda_empty_cache: bool = False, ): # Validate inputs internally v = _CyclicValidator( @@ -128,6 +132,8 @@ def __init__( server_expected_format=server_expected_format, params_transfer_type=params_transfer_type, server_memory_gc_rounds=server_memory_gc_rounds, + client_memory_gc_rounds=client_memory_gc_rounds, + cuda_empty_cache=cuda_empty_cache, ) self.name = v.name @@ -150,6 +156,8 @@ def __init__( self.server_expected_format: ExchangeFormat = v.server_expected_format self.params_transfer_type: TransferType = v.params_transfer_type self.server_memory_gc_rounds = v.server_memory_gc_rounds + self.client_memory_gc_rounds = v.client_memory_gc_rounds + self.cuda_empty_cache = v.cuda_empty_cache # Validate that we have at least one model source if self.model is None and self.initial_ckpt is None: @@ -188,6 +196,8 @@ def __init__( framework=self.framework, server_expected_format=self.server_expected_format, params_transfer_type=self.params_transfer_type, + memory_gc_rounds=self.client_memory_gc_rounds, + cuda_empty_cache=self.cuda_empty_cache, ) job.to_clients(executor) diff --git a/nvflare/recipe/fedavg.py b/nvflare/recipe/fedavg.py index fd59424413..be1dbee2d0 100644 --- a/nvflare/recipe/fedavg.py +++ b/nvflare/recipe/fedavg.py @@ -61,6 +61,8 @@ class _FedAvgValidator(BaseModel): # Memory management server_memory_gc_rounds: int = 0 enable_tensor_disk_offload: bool = False + client_memory_gc_rounds: int = 0 + cuda_empty_cache: bool = False class FedAvgRecipe(Recipe): @@ -181,6 +183,8 @@ def __init__( aggregation_weights: Optional[Dict[str, float]] = None, server_memory_gc_rounds: int = 0, enable_tensor_disk_offload: bool = False, + client_memory_gc_rounds: int = 0, + cuda_empty_cache: bool = False, ): # Validate inputs internally v = _FedAvgValidator( @@ -210,6 +214,8 @@ def __init__( aggregation_weights=aggregation_weights, server_memory_gc_rounds=server_memory_gc_rounds, enable_tensor_disk_offload=enable_tensor_disk_offload, + client_memory_gc_rounds=client_memory_gc_rounds, + cuda_empty_cache=cuda_empty_cache, ) self.name = v.name @@ -246,6 +252,8 @@ def __init__( self.aggregation_weights = v.aggregation_weights self.server_memory_gc_rounds = v.server_memory_gc_rounds self.enable_tensor_disk_offload = v.enable_tensor_disk_offload + self.client_memory_gc_rounds = v.client_memory_gc_rounds + self.cuda_empty_cache = v.cuda_empty_cache # Validate that we have at least one model source # Note: Subclasses (e.g., sklearn) that manage models differently should pass @@ -334,6 +342,8 @@ def __init__( params_transfer_type=transfer_type, launch_once=launch_once, shutdown_timeout=shutdown_timeout, + memory_gc_rounds=self.client_memory_gc_rounds, + cuda_empty_cache=self.cuda_empty_cache, ) job.to(executor, site_name) else: @@ -347,6 +357,8 @@ def __init__( params_transfer_type=self.params_transfer_type, launch_once=self.launch_once, shutdown_timeout=self.shutdown_timeout, + memory_gc_rounds=self.client_memory_gc_rounds, + cuda_empty_cache=self.cuda_empty_cache, ) job.to_clients(executor) diff --git a/tests/unit_test/app_common/ccwf/test_swarm_memory_gc.py b/tests/unit_test/app_common/ccwf/test_swarm_memory_gc.py new file mode 100644 index 0000000000..82a2516db6 --- /dev/null +++ b/tests/unit_test/app_common/ccwf/test_swarm_memory_gc.py @@ -0,0 +1,129 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for SwarmClientController per-round aggregator GC cadence.""" + +from unittest.mock import Mock, patch + +from nvflare.apis.shareable import Shareable +from nvflare.app_common.app_constant import AppConstants +from nvflare.app_common.ccwf.common import Constant +from nvflare.app_common.ccwf.swarm_client_ctl import SwarmClientController + + +def _make_controller(memory_gc_rounds=1, cuda_empty_cache=False): + """Create a minimal SwarmClientController without calling __init__.""" + ctrl = SwarmClientController.__new__(SwarmClientController) + ctrl.memory_gc_rounds = memory_gc_rounds + ctrl.cuda_empty_cache = cuda_empty_cache + ctrl._aggr_round_count = 0 + ctrl.shareable_generator = Mock() + ctrl.shareable_generator.shareable_to_learnable.return_value = Mock() + ctrl.shareable_generator.learnable_to_shareable.return_value = Shareable() + ctrl.record_last_result = Mock() + ctrl._distribute_final_results = Mock() + ctrl._scatter = Mock() + ctrl.log_error = Mock() + ctrl.log_info = Mock() + ctrl.log_debug = Mock() + ctrl.update_status = Mock() + return ctrl + + +def _make_gatherer(for_round=0): + """Create a mock Gatherer.""" + gatherer = Mock() + gatherer.aggregate.return_value = Shareable() + gatherer.for_round = for_round + gatherer.fl_ctx = Mock() + return gatherer + + +def _call_end_gather(ctrl, gatherer, num_rounds_total=5): + """Call _end_gather with get_config_prop mocked.""" + + def get_config_prop(key, default=None): + if key == Constant.START_ROUND: + return 0 + if key == AppConstants.NUM_ROUNDS: + return num_rounds_total + return default + + ctrl.get_config_prop = get_config_prop + ctrl._end_gather(gatherer) + + +class TestSwarmAggregatorMemoryGC: + """Test per-round GC cadence in SwarmClientController._end_gather.""" + + def test_gc_disabled_when_memory_gc_rounds_zero(self): + """cleanup_memory is never called when memory_gc_rounds=0.""" + ctrl = _make_controller(memory_gc_rounds=0) + + with patch("nvflare.fuel.utils.memory_utils.cleanup_memory") as mock_cleanup: + _call_end_gather(ctrl, _make_gatherer(for_round=0)) + mock_cleanup.assert_not_called() + + def test_gc_fires_every_round_when_memory_gc_rounds_one(self): + """cleanup_memory is called every round when memory_gc_rounds=1 (legacy behavior).""" + ctrl = _make_controller(memory_gc_rounds=1) + + with patch("nvflare.fuel.utils.memory_utils.cleanup_memory") as mock_cleanup: + for r in range(3): + _call_end_gather(ctrl, _make_gatherer(for_round=r)) + assert mock_cleanup.call_count == 3 + + def test_gc_fires_every_n_rounds(self): + """cleanup_memory fires every N rounds when memory_gc_rounds=N.""" + ctrl = _make_controller(memory_gc_rounds=2) + + with patch("nvflare.fuel.utils.memory_utils.cleanup_memory") as mock_cleanup: + for r in range(4): + _call_end_gather(ctrl, _make_gatherer(for_round=r)) + # rounds 2 and 4 fire; rounds 1 and 3 do not + assert mock_cleanup.call_count == 2 + + def test_gc_passes_cuda_empty_cache_false(self): + """cuda_empty_cache=False is forwarded to cleanup_memory.""" + ctrl = _make_controller(memory_gc_rounds=1, cuda_empty_cache=False) + + with patch("nvflare.fuel.utils.memory_utils.cleanup_memory") as mock_cleanup: + _call_end_gather(ctrl, _make_gatherer(for_round=0)) + mock_cleanup.assert_called_once_with(cuda_empty_cache=False) + + def test_gc_passes_cuda_empty_cache_true(self): + """cuda_empty_cache=True is forwarded to cleanup_memory (swarm client has GPU).""" + ctrl = _make_controller(memory_gc_rounds=1, cuda_empty_cache=True) + + with patch("nvflare.fuel.utils.memory_utils.cleanup_memory") as mock_cleanup: + _call_end_gather(ctrl, _make_gatherer(for_round=0)) + mock_cleanup.assert_called_once_with(cuda_empty_cache=True) + + def test_gc_fires_on_final_round(self): + """cleanup_memory fires on the final round (not skipped at training end).""" + ctrl = _make_controller(memory_gc_rounds=1) + # for_round=4, num_rounds_total=5 → final round → _distribute_final_results path + with patch("nvflare.fuel.utils.memory_utils.cleanup_memory") as mock_cleanup: + _call_end_gather(ctrl, _make_gatherer(for_round=4), num_rounds_total=5) + mock_cleanup.assert_called_once() + + def test_gc_not_disabled_skips_intermediate_rounds(self): + """With memory_gc_rounds=3, only every 3rd round fires GC.""" + ctrl = _make_controller(memory_gc_rounds=3) + + with patch("nvflare.fuel.utils.memory_utils.cleanup_memory") as mock_cleanup: + for r in range(6): + _call_end_gather(ctrl, _make_gatherer(for_round=r)) + # rounds 3 and 6 fire + assert mock_cleanup.call_count == 2 diff --git a/tests/unit_test/app_common/executors/in_process_client_api_executor_test.py b/tests/unit_test/app_common/executors/in_process_client_api_executor_test.py new file mode 100644 index 0000000000..bf60506bbb --- /dev/null +++ b/tests/unit_test/app_common/executors/in_process_client_api_executor_test.py @@ -0,0 +1,90 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for InProcessClientAPIExecutor memory management parameters.""" + +import pytest + +from nvflare.app_common.executors.in_process_client_api_executor import InProcessClientAPIExecutor +from nvflare.client.config import ExchangeFormat, TransferType + + +class TestInProcessClientAPIExecutorMemory: + """Test memory management parameters in InProcessClientAPIExecutor.""" + + @pytest.fixture + def base_executor_params(self): + """Base parameters for creating executor instances.""" + return { + "task_script_path": "train.py", + "task_script_args": "--epochs 10", + "params_exchange_format": ExchangeFormat.NUMPY, + "params_transfer_type": TransferType.FULL, + } + + def test_default_memory_parameters(self, base_executor_params): + """Test that memory management parameters default to disabled.""" + executor = InProcessClientAPIExecutor(**base_executor_params) + + assert executor._memory_gc_rounds == 0 + assert executor._cuda_empty_cache is False + + def test_memory_parameters_enabled(self, base_executor_params): + """Test memory parameters can be enabled.""" + executor = InProcessClientAPIExecutor( + memory_gc_rounds=5, + cuda_empty_cache=True, + **base_executor_params, + ) + + assert executor._memory_gc_rounds == 5 + assert executor._cuda_empty_cache is True + + @pytest.mark.parametrize( + "gc_rounds,cuda_empty", + [ + (0, False), # Disabled + (1, True), # Every round with CUDA + (1, False), # Every round without CUDA + (5, True), # Every 5 rounds with CUDA + (10, False), # Every 10 rounds without CUDA + ], + ) + def test_memory_parameter_combinations(self, base_executor_params, gc_rounds, cuda_empty): + """Test various memory parameter combinations.""" + executor = InProcessClientAPIExecutor( + memory_gc_rounds=gc_rounds, + cuda_empty_cache=cuda_empty, + **base_executor_params, + ) + + assert executor._memory_gc_rounds == gc_rounds + assert executor._cuda_empty_cache == cuda_empty + + def test_memory_parameters_with_other_options(self, base_executor_params): + """Test that memory parameters work with other executor options.""" + executor = InProcessClientAPIExecutor( + task_wait_time=30.0, + result_pull_interval=1.0, + train_with_evaluation=True, + memory_gc_rounds=2, + cuda_empty_cache=True, + **base_executor_params, + ) + + assert executor._memory_gc_rounds == 2 + assert executor._cuda_empty_cache is True + assert executor._task_wait_time == 30.0 + assert executor._result_pull_interval == 1.0 + assert executor._train_with_evaluation is True diff --git a/tests/unit_test/app_common/statistics/quantile_test.py b/tests/unit_test/app_common/statistics/quantile_test.py index 888eafd181..0f3bc356b2 100644 --- a/tests/unit_test/app_common/statistics/quantile_test.py +++ b/tests/unit_test/app_common/statistics/quantile_test.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import json +import sys import numpy as np import pandas as pd @@ -29,6 +30,8 @@ except ImportError: TDIGEST_AVAILABLE = False +TDIGEST_AVAILABLE = TDIGEST_AVAILABLE and sys.platform != "darwin" + class MockDFStats(DFStatisticsCore): def __init__(self, given_median: int): diff --git a/tests/unit_test/app_opt/tf/tf_recipe_no_cuda_cache_test.py b/tests/unit_test/app_opt/tf/tf_recipe_no_cuda_cache_test.py new file mode 100644 index 0000000000..f800d38a8e --- /dev/null +++ b/tests/unit_test/app_opt/tf/tf_recipe_no_cuda_cache_test.py @@ -0,0 +1,85 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests that TF recipes do not expose cuda_empty_cache. + +TF GPU memory is managed differently from PyTorch; torch.cuda.empty_cache() +is PyTorch-only and must not be offered on TF recipes. +""" + +from unittest.mock import patch + +import pytest + +# All four TF recipes live under app_opt/tf and import TF components at the +# module level, so skip the whole file if TF is not installed. +pytest.importorskip("tensorflow", reason="TensorFlow not available") + + +@pytest.fixture +def mock_file_system(): + with ( + patch("os.path.isfile", return_value=True), + patch("os.path.isdir", return_value=True), + patch("os.path.exists", return_value=True), + ): + yield + + +class TestTFRecipesNoCudaEmptyCache: + """Verify that TF recipes reject cuda_empty_cache as a parameter.""" + + def test_tf_fedavg_rejects_cuda_empty_cache(self, mock_file_system): + from nvflare.app_opt.tf.recipes.fedavg import FedAvgRecipe + + with pytest.raises(TypeError, match="cuda_empty_cache"): + FedAvgRecipe( + min_clients=2, + num_rounds=2, + train_script="train.py", + cuda_empty_cache=True, + ) + + def test_tf_cyclic_rejects_cuda_empty_cache(self, mock_file_system): + from nvflare.app_opt.tf.recipes.cyclic import CyclicRecipe + + with pytest.raises(TypeError, match="cuda_empty_cache"): + CyclicRecipe( + min_clients=2, + num_rounds=2, + train_script="train.py", + cuda_empty_cache=True, + ) + + def test_tf_scaffold_rejects_cuda_empty_cache(self, mock_file_system): + from nvflare.app_opt.tf.recipes.scaffold import ScaffoldRecipe + + with pytest.raises(TypeError, match="cuda_empty_cache"): + ScaffoldRecipe( + min_clients=2, + num_rounds=2, + train_script="train.py", + cuda_empty_cache=True, + ) + + def test_tf_fedopt_rejects_cuda_empty_cache(self, mock_file_system): + from nvflare.app_opt.tf.recipes.fedopt import FedOptRecipe + + with pytest.raises(TypeError, match="cuda_empty_cache"): + FedOptRecipe( + min_clients=2, + num_rounds=2, + train_script="train.py", + cuda_empty_cache=True, + ) diff --git a/tests/unit_test/client/ex_process/__init__.py b/tests/unit_test/client/ex_process/__init__.py new file mode 100644 index 0000000000..4fc25d0d3c --- /dev/null +++ b/tests/unit_test/client/ex_process/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unit_test/client/ex_process/memory_test.py b/tests/unit_test/client/ex_process/memory_test.py new file mode 100644 index 0000000000..cceea6ba0b --- /dev/null +++ b/tests/unit_test/client/ex_process/memory_test.py @@ -0,0 +1,126 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for ExProcessClientAPI memory management.""" + +import os +import unittest +from unittest.mock import patch + + +class TestExProcessClientAPIMemory(unittest.TestCase): + """Test memory management in ExProcessClientAPI.""" + + def test_memory_settings_from_env_disabled(self): + """Test that memory settings default to disabled when env vars not set.""" + # Clear any existing env vars + env = os.environ.copy() + env.pop("NVFLARE_CLIENT_MEMORY_GC_ROUNDS", None) + env.pop("NVFLARE_CUDA_EMPTY_CACHE", None) + + with patch.dict(os.environ, env, clear=True): + # We need to test the __init__ logic without actually initializing + # the full ExProcessClientAPI (which requires config files) + gc_rounds = int(os.environ.get("NVFLARE_CLIENT_MEMORY_GC_ROUNDS", "0")) + cuda_empty = os.environ.get("NVFLARE_CUDA_EMPTY_CACHE", "").lower() == "true" + + assert gc_rounds == 0 + assert cuda_empty is False + + def test_memory_settings_from_env_enabled(self): + """Test that memory settings are read from environment variables.""" + env = { + "NVFLARE_CLIENT_MEMORY_GC_ROUNDS": "5", + "NVFLARE_CUDA_EMPTY_CACHE": "true", + } + + with patch.dict(os.environ, env, clear=False): + gc_rounds = int(os.environ.get("NVFLARE_CLIENT_MEMORY_GC_ROUNDS", "0")) + cuda_empty = os.environ.get("NVFLARE_CUDA_EMPTY_CACHE", "").lower() == "true" + + assert gc_rounds == 5 + assert cuda_empty is True + + def test_memory_settings_env_false(self): + """Test that cuda_empty_cache=false is parsed correctly.""" + env = { + "NVFLARE_CLIENT_MEMORY_GC_ROUNDS": "1", + "NVFLARE_CUDA_EMPTY_CACHE": "false", + } + + with patch.dict(os.environ, env, clear=False): + gc_rounds = int(os.environ.get("NVFLARE_CLIENT_MEMORY_GC_ROUNDS", "0")) + cuda_empty = os.environ.get("NVFLARE_CUDA_EMPTY_CACHE", "").lower() == "true" + + assert gc_rounds == 1 + assert cuda_empty is False + + def test_memory_settings_env_case_insensitive(self): + """Test that TRUE/True/true all work for cuda_empty_cache.""" + for value in ["TRUE", "True", "true", "TrUe"]: + env = {"NVFLARE_CUDA_EMPTY_CACHE": value} + with patch.dict(os.environ, env, clear=False): + cuda_empty = os.environ.get("NVFLARE_CUDA_EMPTY_CACHE", "").lower() == "true" + assert cuda_empty is True, f"Failed for value: {value}" + + +class TestMaybeCleanupMemory(unittest.TestCase): + """Test _maybe_cleanup_memory logic (extracted from ExProcessClientAPI).""" + + def test_cleanup_disabled(self): + """Test that cleanup is skipped when gc_rounds=0.""" + memory_gc_rounds = 0 + round_count = 0 + + # Logic from _maybe_cleanup_memory + if memory_gc_rounds <= 0: + should_cleanup = False + else: + round_count += 1 + should_cleanup = round_count % memory_gc_rounds == 0 + + assert should_cleanup is False + + def test_cleanup_every_round(self): + """Test cleanup every round (gc_rounds=1).""" + memory_gc_rounds = 1 + round_count = 0 + + results = [] + for _ in range(5): + round_count += 1 + should_cleanup = round_count % memory_gc_rounds == 0 + results.append(should_cleanup) + + # Should cleanup every round + assert results == [True, True, True, True, True] + + def test_cleanup_every_n_rounds(self): + """Test cleanup every N rounds.""" + memory_gc_rounds = 3 + round_count = 0 + + results = [] + for _ in range(9): + round_count += 1 + should_cleanup = round_count % memory_gc_rounds == 0 + results.append(should_cleanup) + + # Should cleanup on rounds 3, 6, 9 + expected = [False, False, True, False, False, True, False, False, True] + assert results == expected + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit_test/client/in_process/api_test.py b/tests/unit_test/client/in_process/api_test.py index a873d78499..8df6586c21 100644 --- a/tests/unit_test/client/in_process/api_test.py +++ b/tests/unit_test/client/in_process/api_test.py @@ -92,4 +92,76 @@ def test_init_subscriptions2(self): TOPIC_STOP, ] + def test_memory_management_defaults(self): + """Test that memory management is disabled by default.""" + client_api = InProcessClientAPI(self.task_metadata) + assert client_api._memory_gc_rounds == 0 + assert client_api._cuda_empty_cache is False + assert client_api._round_count == 0 + + def test_configure_memory_management(self): + """Test configure_memory_management method.""" + client_api = InProcessClientAPI(self.task_metadata) + client_api.init() + + client_api.configure_memory_management(gc_rounds=5, cuda_empty_cache=True) + + assert client_api._memory_gc_rounds == 5 + assert client_api._cuda_empty_cache is True + + def test_maybe_cleanup_memory_disabled(self): + """Test that _maybe_cleanup_memory does nothing when disabled.""" + client_api = InProcessClientAPI(self.task_metadata) + client_api.init() + + # With gc_rounds=0 (disabled), should not increment round count + client_api._maybe_cleanup_memory() + assert client_api._round_count == 0 + + def test_maybe_cleanup_memory_enabled(self): + """Test that _maybe_cleanup_memory increments round count when enabled.""" + from unittest.mock import patch + + client_api = InProcessClientAPI(self.task_metadata) + client_api.init() + client_api.configure_memory_management(gc_rounds=2, cuda_empty_cache=False) + + with patch("nvflare.fuel.utils.memory_utils.cleanup_memory") as mock_cleanup: + # First round - should not trigger cleanup + client_api._maybe_cleanup_memory() + assert client_api._round_count == 1 + mock_cleanup.assert_not_called() + + # Second round - should trigger cleanup (every 2 rounds) + client_api._maybe_cleanup_memory() + assert client_api._round_count == 2 + mock_cleanup.assert_called_once_with(cuda_empty_cache=False) + + # Third round - should not trigger cleanup + mock_cleanup.reset_mock() + client_api._maybe_cleanup_memory() + assert client_api._round_count == 3 + mock_cleanup.assert_not_called() + + # Fourth round - should trigger cleanup + client_api._maybe_cleanup_memory() + assert client_api._round_count == 4 + mock_cleanup.assert_called_once() + + def test_maybe_cleanup_memory_every_round(self): + """Test cleanup every round (gc_rounds=1).""" + from unittest.mock import patch + + client_api = InProcessClientAPI(self.task_metadata) + client_api.init() + client_api.configure_memory_management(gc_rounds=1, cuda_empty_cache=True) + + with patch("nvflare.fuel.utils.memory_utils.cleanup_memory") as mock_cleanup: + client_api._maybe_cleanup_memory() + mock_cleanup.assert_called_with(cuda_empty_cache=True) + + mock_cleanup.reset_mock() + client_api._maybe_cleanup_memory() + mock_cleanup.assert_called_with(cuda_empty_cache=True) + # Add more test methods for other functionalities in the class diff --git a/tests/unit_test/fuel/utils/memory_utils_test.py b/tests/unit_test/fuel/utils/memory_utils_test.py index 63c09d7244..85cd29bb75 100644 --- a/tests/unit_test/fuel/utils/memory_utils_test.py +++ b/tests/unit_test/fuel/utils/memory_utils_test.py @@ -36,27 +36,37 @@ def test_cleanup_memory_calls_gc_collect(self): cleanup_memory() mock_gc.assert_called_once() - def test_cleanup_memory_calls_try_malloc_trim(self): - """Test that cleanup_memory calls try_malloc_trim.""" + def test_cleanup_memory_calls_try_malloc_trim_for_glibc(self): + """Test that cleanup_memory calls try_malloc_trim for glibc allocator.""" from nvflare.fuel.utils.memory_utils import cleanup_memory - with patch("nvflare.fuel.utils.memory_utils.try_malloc_trim") as mock_trim: - cleanup_memory() - mock_trim.assert_called_once() + with patch("nvflare.fuel.utils.memory_utils.get_allocator_type", return_value="glibc"): + with patch("nvflare.fuel.utils.memory_utils.try_malloc_trim") as mock_trim: + cleanup_memory() + mock_trim.assert_called_once() + + def test_cleanup_memory_skips_malloc_trim_for_jemalloc(self): + """Test that cleanup_memory skips try_malloc_trim for jemalloc allocator.""" + from nvflare.fuel.utils.memory_utils import cleanup_memory + + with patch("nvflare.fuel.utils.memory_utils.get_allocator_type", return_value="jemalloc"): + with patch("nvflare.fuel.utils.memory_utils.try_malloc_trim") as mock_trim: + cleanup_memory() + mock_trim.assert_not_called() - def test_cleanup_memory_torch_cuda_empty_cache_false(self): - """Test that cleanup_memory with torch_cuda_empty_cache=False does not call torch.""" + def test_cleanup_memory_cuda_empty_cache_false(self): + """Test that cleanup_memory with cuda_empty_cache=False does not call torch.""" from nvflare.fuel.utils.memory_utils import cleanup_memory # This should not raise and should not try to import torch - cleanup_memory(torch_cuda_empty_cache=False) + cleanup_memory(cuda_empty_cache=False) - def test_cleanup_memory_torch_cuda_empty_cache_true(self): - """Test that cleanup_memory handles torch_cuda_empty_cache=True gracefully.""" + def test_cleanup_memory_cuda_empty_cache_true(self): + """Test that cleanup_memory handles cuda_empty_cache=True gracefully.""" from nvflare.fuel.utils.memory_utils import cleanup_memory # This should not raise even if torch is not installed or CUDA unavailable - cleanup_memory(torch_cuda_empty_cache=True) + cleanup_memory(cuda_empty_cache=True) def test_get_glibc_caching(self): """Test that _get_glibc is cached (only loads once).""" @@ -74,3 +84,40 @@ def test_get_glibc_caching(self): # Check cache info cache_info = _get_glibc.cache_info() assert cache_info.hits >= 1 # Second call should be a cache hit + + def test_get_allocator_type_returns_valid_string(self): + """Test that get_allocator_type returns a valid allocator type.""" + from nvflare.fuel.utils.memory_utils import get_allocator_type + + # Clear cache first + get_allocator_type.cache_clear() + + result = get_allocator_type() + assert result in ("glibc", "jemalloc", "unknown") + + def test_get_allocator_type_caching(self): + """Test that get_allocator_type is cached (only detects once).""" + from nvflare.fuel.utils.memory_utils import get_allocator_type + + # Clear the cache first + get_allocator_type.cache_clear() + + result1 = get_allocator_type() + result2 = get_allocator_type() + + # Should return the same result + assert result1 == result2 + + # Check cache info - second call should be a cache hit + cache_info = get_allocator_type.cache_info() + assert cache_info.hits >= 1 + + def test_cleanup_memory_allocator_aware(self): + """Test that cleanup_memory adapts behavior based on allocator type.""" + from nvflare.fuel.utils.memory_utils import cleanup_memory, get_allocator_type + + # This should work regardless of allocator type + get_allocator_type.cache_clear() + cleanup_memory() + + # Verify it completed without error - allocator-specific logic handled internally diff --git a/tests/unit_test/job_config/script_runner_test.py b/tests/unit_test/job_config/script_runner_test.py index ac771231ce..eb71b17cf3 100644 --- a/tests/unit_test/job_config/script_runner_test.py +++ b/tests/unit_test/job_config/script_runner_test.py @@ -215,3 +215,89 @@ def test_custom_launcher_passed_through(self, mock_file_system, base_script_runn assert runner._launch_once is False assert runner._shutdown_timeout == 100.0 assert runner._launcher is custom_launcher + + +class TestScriptRunnerMemoryManagement: + """Test cases for ScriptRunner memory management parameters.""" + + @pytest.fixture + def base_script_runner_params(self): + """Base parameters for creating ScriptRunner instances.""" + return { + "script": "train.py", + "script_args": "--epochs 10", + "framework": FrameworkType.PYTORCH, + } + + def test_default_memory_parameters(self, base_script_runner_params): + """Test that memory management parameters default to disabled.""" + runner = ScriptRunner(**base_script_runner_params) + + assert runner._memory_gc_rounds == 0 + assert runner._cuda_empty_cache is False + + @pytest.mark.parametrize( + "memory_gc_rounds,cuda_empty_cache", + [ + (0, False), # Disabled + (1, True), # Every round with cuda cache + (5, False), # Every 5 rounds without cuda cache + (10, True), # Every 10 rounds with cuda cache + ], + ) + def test_memory_parameter_configurations(self, base_script_runner_params, memory_gc_rounds, cuda_empty_cache): + """Test various memory management configurations.""" + runner = ScriptRunner( + memory_gc_rounds=memory_gc_rounds, + cuda_empty_cache=cuda_empty_cache, + **base_script_runner_params, + ) + + assert runner._memory_gc_rounds == memory_gc_rounds + assert runner._cuda_empty_cache == cuda_empty_cache + + def test_memory_parameters_with_external_process(self, base_script_runner_params): + """Test memory parameters with external process mode.""" + runner = ScriptRunner( + launch_external_process=True, + memory_gc_rounds=3, + cuda_empty_cache=True, + **base_script_runner_params, + ) + + assert runner._memory_gc_rounds == 3 + assert runner._cuda_empty_cache is True + assert runner._launch_external_process is True + + def test_memory_parameters_with_in_process(self, base_script_runner_params): + """Test memory parameters with in-process mode.""" + runner = ScriptRunner( + launch_external_process=False, + memory_gc_rounds=2, + cuda_empty_cache=True, + **base_script_runner_params, + ) + + assert runner._memory_gc_rounds == 2 + assert runner._cuda_empty_cache is True + assert runner._launch_external_process is False + + @pytest.mark.parametrize( + "framework", + [ + FrameworkType.PYTORCH, + FrameworkType.NUMPY, + ], + ) + def test_memory_parameters_with_different_frameworks(self, framework): + """Test that memory parameters work with different frameworks.""" + runner = ScriptRunner( + script="train.py", + memory_gc_rounds=1, + cuda_empty_cache=True, + framework=framework, + ) + + assert runner._memory_gc_rounds == 1 + assert runner._cuda_empty_cache is True + assert runner._framework == framework diff --git a/tests/unit_test/recipe/swarm_recipe_test.py b/tests/unit_test/recipe/swarm_recipe_test.py index 432e2d9da5..1f992c5e16 100644 --- a/tests/unit_test/recipe/swarm_recipe_test.py +++ b/tests/unit_test/recipe/swarm_recipe_test.py @@ -177,3 +177,68 @@ def test_train_args_valid_keys_accepted(self, mock_file_system, simple_pt_model) ) assert recipe.job is not None + + +class TestSimpleSwarmLearningRecipeMemoryGC: + """Test memory GC parameters on SimpleSwarmLearningRecipe.""" + + def test_default_memory_gc_rounds_is_one(self): + """Default memory_gc_rounds=1 for backward compatibility with legacy GC behavior.""" + import inspect + + from nvflare.app_opt.pt.recipes.swarm import SimpleSwarmLearningRecipe + + sig = inspect.signature(SimpleSwarmLearningRecipe.__init__) + assert sig.parameters["memory_gc_rounds"].default == 1 + + def test_old_param_name_rejected(self, mock_file_system, simple_pt_model): + """client_memory_gc_rounds (old name) is no longer accepted.""" + from nvflare.app_opt.pt.recipes.swarm import SimpleSwarmLearningRecipe + + with pytest.raises(TypeError, match="client_memory_gc_rounds"): + SimpleSwarmLearningRecipe( + name="test_swarm", + model=simple_pt_model, + num_rounds=5, + train_script="train.py", + client_memory_gc_rounds=2, + ) + + def test_memory_gc_rounds_custom_accepted(self, mock_file_system, simple_pt_model): + """Custom memory_gc_rounds is accepted.""" + from nvflare.app_opt.pt.recipes.swarm import SimpleSwarmLearningRecipe + + recipe = SimpleSwarmLearningRecipe( + name="test_swarm", + model=simple_pt_model, + num_rounds=5, + train_script="train.py", + memory_gc_rounds=2, + ) + assert recipe.job is not None + + def test_memory_gc_disabled_accepted(self, mock_file_system, simple_pt_model): + """memory_gc_rounds=0 disables GC.""" + from nvflare.app_opt.pt.recipes.swarm import SimpleSwarmLearningRecipe + + recipe = SimpleSwarmLearningRecipe( + name="test_swarm", + model=simple_pt_model, + num_rounds=5, + train_script="train.py", + memory_gc_rounds=0, + ) + assert recipe.job is not None + + def test_cuda_empty_cache_accepted(self, mock_file_system, simple_pt_model): + """cuda_empty_cache=True is accepted and wired through.""" + from nvflare.app_opt.pt.recipes.swarm import SimpleSwarmLearningRecipe + + recipe = SimpleSwarmLearningRecipe( + name="test_swarm", + model=simple_pt_model, + num_rounds=5, + train_script="train.py", + cuda_empty_cache=True, + ) + assert recipe.job is not None