Changes from 1 commit
186 changes: 171 additions & 15 deletions docs/programming_guide/memory_management.rst
@@ -5,7 +5,7 @@ Memory Management
###################

This guide describes memory management techniques for long-running federated learning jobs
using Python, PyTorch, and glibc/jemalloc.

.. contents:: Table of Contents
    :local:
@@ -21,7 +21,24 @@ RSS (Resident Set Size) can grow continuously due to:
- glibc memory arena fragmentation
- PyTorch CUDA cache retention

NVFlare provides utilities and configuration options to manage memory effectively on
both server and client sides. The framework automatically detects the memory allocator
in use (glibc or jemalloc) and adapts its cleanup strategy accordingly.

Allocator Support
=================

NVFlare supports two memory allocators:

**glibc (default on most Linux)**
Uses ``malloc_trim()`` to release free heap pages to the OS.
Requires ``MALLOC_ARENA_MAX`` for optimal memory behavior.

**jemalloc (recommended for PyTorch)**
Uses auto-decay for memory management. Configure via ``MALLOC_CONF``.
No ``malloc_trim()`` calls needed (jemalloc handles this automatically).

NVFlare automatically detects which allocator is in use at runtime.

Platform Compatibility
======================
@@ -139,14 +156,130 @@ Memory cleanup has minimal overhead in typical federated learning workloads:
management with negligible performance impact. Only disable (``=0``) if you've
measured and confirmed RSS is stable without cleanup.

Client-Side Memory Cleanup
==========================

The FedAvg recipe and ScriptRunner support automatic memory cleanup on clients via the
``client_memory_gc_rounds`` and ``cuda_empty_cache`` parameters.

Configuration
-------------

.. code-block:: python

    from nvflare.recipe.fedavg import FedAvgRecipe

    recipe = FedAvgRecipe(
        name="my_job",
        min_clients=4,
        num_rounds=100,
        train_script="client.py",

        # Server-side cleanup
        server_memory_gc_rounds=5,

        # Client-side cleanup
        client_memory_gc_rounds=1,  # Cleanup every round
        cuda_empty_cache=True,      # Clear GPU cache
    )

Swarm Learning Configuration
----------------------------

Swarm Learning uses ``memory_gc_rounds`` (not ``memory_gc_counts``) and
``cuda_empty_cache`` on ``SimpleSwarmLearningRecipe``:

.. code-block:: python

    from nvflare.app_opt.pt.recipes.swarm import SimpleSwarmLearningRecipe

    recipe = SimpleSwarmLearningRecipe(
        name="swarm_job",
        model=MyModel(),
        num_rounds=10,
        train_script="train.py",
        memory_gc_rounds=1,  # Cleanup every round on trainer and aggregator roles
        cuda_empty_cache=True,
    )

.. note::

    ``memory_gc_rounds`` and ``cuda_empty_cache`` are top-level Swarm recipe arguments.
    Do not pass them inside ``train_args`` (they are reserved keys).

**Parameters:**

- ``client_memory_gc_rounds``: Run cleanup every N rounds on client (0 = disabled)
- ``cuda_empty_cache``: If True, call ``torch.cuda.empty_cache()`` on cleanup
- ``memory_gc_rounds`` (Swarm): Run cleanup every N rounds (0 = disabled)
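The every-N-rounds gating behind these parameters can be sketched in plain Python. This is an illustrative approximation, not NVFlare's actual code; the counter-modulo logic mirrors how the aggregator-side counter is described later in this change:

```python
import gc

def maybe_cleanup(round_count: int, gc_rounds: int) -> bool:
    """Illustrative sketch: run cleanup when the round counter is a multiple of gc_rounds.

    A gc_rounds value of 0 disables cleanup entirely.
    """
    if gc_rounds > 0 and round_count % gc_rounds == 0:
        gc.collect()  # stand-in for the full cleanup sequence
        return True
    return False

print(maybe_cleanup(5, 5))  # True: round 5 with gc_rounds=5 triggers cleanup
print(maybe_cleanup(3, 0))  # False: gc_rounds=0 disables cleanup
```

Setting the interval to 1 runs cleanup after every round; larger values trade lower overhead for slower RSS reclamation.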

What It Does
------------

When enabled, after each ``flare.send()`` on the client:

1. Runs Python garbage collection (``gc.collect()``)
2. For glibc: Returns free heap pages to OS (``malloc_trim()``)
3. For jemalloc: Relies on auto-decay (no manual action needed)
4. Optionally clears PyTorch CUDA cache

.. note::

    The cleanup is transparent to the user's training script. No code changes
    are required in ``train.py``.
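The four steps above can be sketched as follows. This is a simplified, platform-guarded approximation of the cleanup sequence, not the actual NVFlare implementation:

```python
import ctypes
import gc

def cleanup_sketch(cuda_empty_cache: bool = False) -> None:
    """Simplified sketch of the allocator-aware cleanup sequence."""
    gc.collect()  # 1. Python garbage collection
    try:
        # 2. glibc only: return free heap pages to the OS
        ctypes.CDLL("libc.so.6").malloc_trim(0)
    except (OSError, AttributeError):
        pass  # 3. jemalloc/musl/macOS: safe no-op; jemalloc decays on its own
    if cuda_empty_cache:
        try:
            import torch  # optional dependency
            if torch.cuda.is_available():
                torch.cuda.empty_cache()  # 4. release cached GPU blocks
        except ImportError:
            pass

cleanup_sketch()  # safe to call on any platform
```

The glibc branch is wrapped in a try/except so the same code is a no-op on systems without ``malloc_trim``, which matches the allocator-aware behavior described above.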

External Process Support
------------------------

For external process execution (``launch_external_process=True``), memory settings
are passed via environment variables:

- ``NVFLARE_CLIENT_MEMORY_GC_ROUNDS``: Cleanup interval
- ``NVFLARE_CUDA_EMPTY_CACHE``: GPU cache cleanup (``true``/``false``)
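On the launched-process side, reading these variables might look like the following sketch (the parsing inside NVFlare itself may differ; only the variable names and value formats come from the list above):

```python
import os

def read_memory_settings() -> tuple:
    """Parse the memory-cleanup environment variables with safe defaults."""
    gc_rounds = int(os.environ.get("NVFLARE_CLIENT_MEMORY_GC_ROUNDS", "0"))
    empty_cache = os.environ.get("NVFLARE_CUDA_EMPTY_CACHE", "false").lower() == "true"
    return gc_rounds, empty_cache

# Example: values as the parent process would set them
os.environ["NVFLARE_CLIENT_MEMORY_GC_ROUNDS"] = "1"
os.environ["NVFLARE_CUDA_EMPTY_CACHE"] = "true"
print(read_memory_settings())  # (1, True)
```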

Recommended Settings
====================

+--------+-----------------------------+-----------------------------+----------------------+----------------------+
| Role | ``server_memory_gc_rounds`` | ``client_memory_gc_rounds`` | ``MALLOC_ARENA_MAX`` | ``cuda_empty_cache`` |
+========+=============================+=============================+======================+======================+
| Server | 5 | N/A | 4 | N/A |
+--------+-----------------------------+-----------------------------+----------------------+----------------------+
| Client | N/A | 1 | 2 | True (for GPU) |
+--------+-----------------------------+-----------------------------+----------------------+----------------------+
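Because ``MALLOC_ARENA_MAX`` must be present in a process's environment before it starts, the table translates into launch-time environment setup. A sketch (the startup command is hypothetical):

```python
import os

# Server process: cap glibc arenas at 4 per the recommendations above
server_env = dict(os.environ, MALLOC_ARENA_MAX="4")
# Client process would use "2" instead
client_env = dict(os.environ, MALLOC_ARENA_MAX="2")

# subprocess.run(["./start.sh"], env=server_env)  # hypothetical launch, shown disabled
print(server_env["MALLOC_ARENA_MAX"])  # 4
```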

Using jemalloc
==============

For PyTorch workloads, jemalloc is recommended over glibc malloc. NVFlare's startup
scripts automatically detect and use jemalloc if available.

Startup Script
--------------

The generated ``sub_start.sh`` script includes jemalloc auto-detection:

.. code-block:: bash

    # Auto-detects jemalloc at standard locations
    for JEMALLOC in /usr/lib/x86_64-linux-gnu/libjemalloc.so.2 \
                    /usr/lib64/libjemalloc.so.2 \
                    /usr/local/lib/libjemalloc.so; do
        if [ -f "$JEMALLOC" ]; then
            export LD_PRELOAD="${LD_PRELOAD:+$LD_PRELOAD:}$JEMALLOC"
            export MALLOC_CONF="${MALLOC_CONF:-dirty_decay_ms:5000,muzzy_decay_ms:5000}"
            break
        fi
    done

Installing jemalloc
-------------------

.. code-block:: bash

    # Ubuntu/Debian
    apt-get install libjemalloc2

    # RHEL/CentOS
    yum install jemalloc
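Since ``LD_PRELOAD`` only takes effect when a process starts, enabling jemalloc manually means injecting it into a child process's environment rather than the running interpreter. A sketch (the library path is the typical Ubuntu location and the launch command is hypothetical; both may differ on your system):

```python
import os

# Build an environment for a child process with jemalloc preloaded.
# Path is distro-dependent (Ubuntu shown); decay settings control how
# quickly freed memory is returned to the OS.
jemalloc_env = dict(
    os.environ,
    LD_PRELOAD="/usr/lib/x86_64-linux-gnu/libjemalloc.so.2",
    MALLOC_CONF="dirty_decay_ms:5000,muzzy_decay_ms:5000",
)
# subprocess.run(["python", "client.py"], env=jemalloc_env)  # hypothetical launch
```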

API Reference
=============
@@ -158,15 +291,29 @@ cleanup_memory

.. code-block:: python

    from nvflare.fuel.utils.memory_utils import cleanup_memory

    cleanup_memory(cuda_empty_cache=True)

**Signature:** ``cleanup_memory(cuda_empty_cache: bool = False) -> None``

Performs allocator-aware memory cleanup:

1. Runs ``gc.collect()``
2. For glibc: Calls ``malloc_trim(0)``
3. For jemalloc: Relies on auto-decay (no action needed)
4. Optionally calls ``torch.cuda.empty_cache()``

get_allocator_type
------------------

.. code-block:: python

    from nvflare.fuel.utils.memory_utils import get_allocator_type

    allocator = get_allocator_type()  # "glibc", "jemalloc", or "unknown"

**Signature:** ``get_allocator_type() -> str``

Detects which memory allocator is in use at runtime. Result is cached.
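A simplified sketch of how cached allocator detection can work — illustrative only, and not NVFlare's actual logic; the detection heuristics here (``LD_PRELOAD`` inspection, probing for the glibc-only ``malloc_trim`` symbol) are assumptions:

```python
import ctypes
import os
from functools import lru_cache

@lru_cache(maxsize=1)
def detect_allocator() -> str:
    """Detect the allocator once; subsequent calls return the cached result."""
    if "jemalloc" in os.environ.get("LD_PRELOAD", ""):
        return "jemalloc"  # jemalloc is typically injected via LD_PRELOAD
    try:
        ctypes.CDLL("libc.so.6").malloc_trim  # symbol exists only on glibc
        return "glibc"
    except (OSError, AttributeError):
        return "unknown"  # musl, macOS, etc.

print(detect_allocator() in ("glibc", "jemalloc", "unknown"))  # True
```

``lru_cache(maxsize=1)`` gives the same detect-once-then-cache behavior the reference describes.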

try_malloc_trim
---------------
@@ -195,12 +342,21 @@ High RSS on Server

1. Check ``MALLOC_ARENA_MAX`` is set
2. Enable ``server_memory_gc_rounds=5``
3. Consider using jemalloc (``LD_PRELOAD``)
4. Monitor with ``top`` or ``htop``

High RSS on Client
------------------

1. Check ``MALLOC_ARENA_MAX=2`` is set
2. Enable ``client_memory_gc_rounds=1``
3. Enable ``cuda_empty_cache=True`` for GPU
4. Consider using jemalloc

OOM Errors
----------

1. Reduce batch size
2. Enable memory cleanup every round (``client_memory_gc_rounds=1`` or ``server_memory_gc_rounds=1``)
3. Check for memory leaks in training code
4. Use jemalloc with appropriate decay settings
6 changes: 6 additions & 0 deletions nvflare/app_common/ccwf/ccwf_job.py
@@ -84,6 +84,8 @@ def __init__(
request_to_submit_result_interval: float = 1.0,
max_concurrent_submissions: int = 1,
enable_tensor_disk_offload: bool = False,
memory_gc_rounds: int = 1,
cuda_empty_cache: bool = False,
):
# the executor could be a wrapper object that adds real Executor when added to job!
validate_object_for_job("executor", executor, Executor)
@@ -115,6 +117,8 @@ def __init__(
self.request_to_submit_result_interval = request_to_submit_result_interval
self.max_concurrent_submissions = max_concurrent_submissions
self.enable_tensor_disk_offload = enable_tensor_disk_offload
self.memory_gc_rounds = memory_gc_rounds
self.cuda_empty_cache = cuda_empty_cache


class CyclicServerConfig:
@@ -275,6 +279,8 @@ def add_swarm(
request_to_submit_result_interval=client_config.request_to_submit_result_interval,
max_concurrent_submissions=client_config.max_concurrent_submissions,
enable_tensor_disk_offload=client_config.enable_tensor_disk_offload,
memory_gc_rounds=client_config.memory_gc_rounds,
cuda_empty_cache=client_config.cuda_empty_cache,
)
self.to_clients(client_controller, tasks=["swarm_*"])
if not self.executor:
46 changes: 30 additions & 16 deletions nvflare/app_common/ccwf/swarm_client_ctl.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import gc
import random
import threading
import time
@@ -110,9 +109,6 @@ def gather(self, client_name: str, result: Shareable, fl_ctx: FLContext) -> Shar
if client_status:
client_status.busy = False

# force garbage collection after each gather
gc.collect()

def can_accept_submission(self, client_name: str, result: Shareable, fl_ctx: FLContext) -> str:
with self.perm_lock:
result_round = result.get_header(AppConstants.CURRENT_ROUND)
@@ -295,6 +291,8 @@ def __init__(
request_to_submit_result_interval: float = 1.0,
max_concurrent_submissions: int = 1,
enable_tensor_disk_offload: bool = False,
memory_gc_rounds: int = 1,
cuda_empty_cache: bool = False,
):
"""
Constructor of a ClientSideController object.
@@ -325,6 +323,12 @@ def __init__(
max_concurrent_submissions: max number of concurrent submissions allowed on the aggregation client.
enable_tensor_disk_offload: download tensors to disk during FOBS streaming instead of into memory.
Reduces peak memory during aggregation. Aggregators must handle lazy refs.
memory_gc_rounds: run gc.collect() + malloc_trim on the aggregator every N FL rounds.
Defaults to 1 (every round) to match legacy behavior where gc.collect() was called
unconditionally after each trainer submission. Set to 0 to disable.
cuda_empty_cache: also call torch.cuda.empty_cache() during aggregator-side cleanup.
In swarm learning the aggregator runs on the same client as the trainer, so GPU
memory may be relevant. Defaults to False.

Note that if the max_concurrent_submissions is set to 1, it practically means that all training results
will be submitted to the aggregation client sequentially. This lowers the resource pressure on
@@ -386,6 +390,9 @@ def __init__(
self.last_aggr_round_done = -1
self.enable_tensor_disk_offload = enable_tensor_disk_offload
self._previous_enable_tensor_disk_offload = None
self.memory_gc_rounds = memory_gc_rounds
self.cuda_empty_cache = cuda_empty_cache
self._aggr_round_count = 0

def process_config(self, fl_ctx: FLContext):
all_clients = self.get_config_prop(Constant.CLIENTS)
@@ -585,22 +592,29 @@ def _end_gather(self, gatherer: Gatherer):

# determine the best global result
self._distribute_final_results(aggr_result, fl_ctx)
return

# continue next round
next_round_data = self.shareable_generator.learnable_to_shareable(global_weights, fl_ctx)
assert isinstance(next_round_data, Shareable)
else:
# continue next round
next_round_data = self.shareable_generator.learnable_to_shareable(global_weights, fl_ctx)
assert isinstance(next_round_data, Shareable)

best_round = aggr_result.get_header(Constant.ROUND)
best_metric = aggr_result.get_header(Constant.METRIC)
best_client = aggr_result.get_header(Constant.CLIENT)
best_round = aggr_result.get_header(Constant.ROUND)
best_metric = aggr_result.get_header(Constant.METRIC)
best_client = aggr_result.get_header(Constant.CLIENT)

if best_client:
next_round_data.set_header(Constant.ROUND, best_round)
next_round_data.set_header(Constant.CLIENT, best_client)
next_round_data.set_header(Constant.METRIC, best_metric)
if best_client:
next_round_data.set_header(Constant.ROUND, best_round)
next_round_data.set_header(Constant.CLIENT, best_client)
next_round_data.set_header(Constant.METRIC, best_metric)

self._scatter(next_round_data, gatherer.for_round + 1, gatherer.fl_ctx)

if self.memory_gc_rounds > 0:
self._aggr_round_count += 1
if self._aggr_round_count % self.memory_gc_rounds == 0:
from nvflare.fuel.utils.memory_utils import cleanup_memory

self._scatter(next_round_data, gatherer.for_round + 1, gatherer.fl_ctx)
cleanup_memory(cuda_empty_cache=self.cuda_empty_cache)

def _ask_to_share_best_result(self, client: str, metric, fl_ctx: FLContext):
# other client has best model - ask it to distribute its result
6 changes: 6 additions & 0 deletions nvflare/app_common/executors/client_api_launcher_executor.py
@@ -50,6 +50,8 @@ def __init__(
params_transfer_type: str = TransferType.FULL,
config_file_name: str = CLIENT_API_CONFIG,
server_expected_format: str = ExchangeFormat.NUMPY,
memory_gc_rounds: int = 0,
cuda_empty_cache: bool = False,
) -> None:
"""Initializes the ClientAPILauncherExecutor.

@@ -107,6 +109,8 @@ def __init__(
self._params_exchange_format = params_exchange_format
self._params_transfer_type = params_transfer_type
self._config_file_name = config_file_name
self._memory_gc_rounds = memory_gc_rounds
self._cuda_empty_cache = cuda_empty_cache

def initialize(self, fl_ctx: FLContext) -> None:
self.prepare_config_for_launch(fl_ctx)
@@ -141,6 +145,8 @@ def prepare_config_for_launch(self, fl_ctx: FLContext):
ConfigKey.ARG: pipe_export_args,
},
ConfigKey.HEARTBEAT_TIMEOUT: self.heartbeat_timeout,
ConfigKey.MEMORY_GC_ROUNDS: self._memory_gc_rounds,
ConfigKey.CUDA_EMPTY_CACHE: self._cuda_empty_cache,
}

config_data = {