Memory Management

This guide describes memory management techniques for long-running federated learning jobs using Python, PyTorch, and glibc/jemalloc.

Federated learning jobs can run for hours or days. Without proper memory management, RSS (Resident Set Size) can grow continuously due to:

  • Python garbage collection delays
  • glibc memory arena fragmentation
  • PyTorch CUDA cache retention

NVFlare provides utilities and configuration options to manage memory effectively on both server and client sides. The framework automatically detects the memory allocator in use (glibc or jemalloc) and adapts its cleanup strategy accordingly.

NVFlare supports two memory allocators:

glibc (default on most Linux distributions)
Uses malloc_trim() to release free heap pages to the OS. Set MALLOC_ARENA_MAX to limit the number of malloc arenas and the fragmentation they cause.
jemalloc (recommended for PyTorch)
Returns unused pages to the OS automatically through time-based decay, configured via MALLOC_CONF. No malloc_trim() calls are needed.

NVFlare automatically detects which allocator is in use at runtime.

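Not all memory management features work on all platforms. The table below summarizes compatibility:

Feature                    Linux/glibc        Linux/musl         macOS
gc.collect()               Yes                Yes                Yes
MALLOC_ARENA_MAX           Yes                No (ignored)       No (ignored)
malloc_trim()              Yes                No (skipped)       No (skipped)
torch.cuda.empty_cache()   Yes (with CUDA)    Yes (with CUDA)    No-op (no CUDA)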

Notes:

  • Linux/glibc: Standard Linux distributions (Ubuntu, RHEL, Debian, etc.)
  • Linux/musl: Alpine Linux and other musl-based distributions
  • macOS: malloc_trim() is silently skipped (safe no-op)

Warning

For maximum memory efficiency, use Linux with glibc. Alpine Linux (musl) and macOS/Windows still benefit from gc.collect(), but malloc_trim() is unavailable there, so fragmented free heap memory cannot be explicitly returned to the OS.

Set these environment variables before starting NVFlare processes.

Client:

export MALLOC_ARENA_MAX=2

Why: Clients typically have limited CPU memory. Setting MALLOC_ARENA_MAX=2 prevents arena explosion (by default, 64-bit glibc may create up to 8 malloc arenas per CPU core for concurrent threads) and reduces memory fragmentation.

Server:

export MALLOC_ARENA_MAX=4

Why: Servers are CPU-memory heavy (typically 4-7× the model size) and run multi-threaded networking. MALLOC_ARENA_MAX=4 balances allocation throughput against memory use. Use 8 for highly parallel workloads.
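glibc reads MALLOC_ARENA_MAX when a process starts. The variable therefore has to be present in the environment of the NVFlare process itself. Changing os.environ inside an already running Python interpreter does not reconfigure that interpreter's own allocator. The following is a minimal sketch for launching a child process from Python with the variable in place. The helper name arena_env and the role-to-value mapping are illustrative assumptions based on the recommendations above. They are not an NVFlare API.

```python
import os
import subprocess
import sys
from typing import Dict, Mapping, Optional

# Arena limits recommended in this guide (illustrative mapping, not an NVFlare API).
ARENA_MAX_BY_ROLE = {"client": 2, "server": 4}


def arena_env(role: str, base: Optional[Mapping[str, str]] = None) -> Dict[str, str]:
    """Return a copy of the environment with MALLOC_ARENA_MAX set for the role.

    An explicit MALLOC_ARENA_MAX already present in the base environment is kept.
    """
    env = dict(os.environ if base is None else base)
    env.setdefault("MALLOC_ARENA_MAX", str(ARENA_MAX_BY_ROLE[role]))
    return env


if __name__ == "__main__":
    # The child sees the variable at startup, so glibc applies it there.
    child = subprocess.run(
        [sys.executable, "-c", "import os; print(os.environ['MALLOC_ARENA_MAX'])"],
        env=arena_env("client"),
        capture_output=True,
        text=True,
        check=True,
    )
    print(child.stdout.strip())
```

Using setdefault lets an operator override the role default by exporting MALLOC_ARENA_MAX explicitly (for example, 8 on a highly parallel server).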

The FedAvg controller supports automatic memory cleanup via the server_memory_gc_rounds parameter.

from nvflare.recipe.fedavg import FedAvgRecipe

recipe = FedAvgRecipe(
    name="my_job",
    min_clients=4,
    num_rounds=100,
    train_script="client.py",
    server_memory_gc_rounds=5,  # Cleanup every 5 rounds
)

Values:

  • 0 = Disabled (default for BaseFedAvg-based controllers)
  • 1 = Cleanup every round (default for legacy controllers like ScatterAndGather)
  • 5 = Cleanup every 5 rounds (recommended for server)

When enabled, at the end of every N rounds:

  1. Runs Python garbage collection (gc.collect())
  2. Returns free heap pages to the OS (malloc_trim(), Linux/glibc only)
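The every-N-rounds trigger can be sketched as below. This is an illustration, not NVFlare's controller code. The function name maybe_cleanup, the 0-based round index, and firing at the end of rounds N, 2N, ... are all assumptions. The default cleanup here is gc.collect() alone. Step 2 (malloc_trim) is left out so the sketch stays focused on scheduling.

```python
import gc
from typing import Callable


def maybe_cleanup(
    round_idx: int,
    gc_rounds: int,
    cleanup: Callable[[], object] = gc.collect,
) -> bool:
    """Run `cleanup` at the end of every `gc_rounds`-th round.

    round_idx is 0-based, so with gc_rounds=5 cleanup runs after rounds 4, 9, 14, ...
    gc_rounds <= 0 disables cleanup. Returns True if cleanup ran.
    """
    if gc_rounds <= 0:
        return False  # 0 = disabled
    if (round_idx + 1) % gc_rounds != 0:
        return False
    cleanup()
    return True


if __name__ == "__main__":
    ran = [r for r in range(12) if maybe_cleanup(r, 5, cleanup=lambda: None)]
    print(ran)  # [4, 9]
```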

Memory cleanup has minimal overhead in typical federated learning workloads:

Operation       Typical duration   Notes
gc.collect()    10-500 ms          Depends on Python object count
malloc_trim()   < 1 ms             Very fast (page table operations)

Overhead analysis:

  • Training round duration: Typically 30 seconds to 10+ minutes
  • Cleanup duration: 10-500 ms total
  • Overhead per round: Usually < 1% (for example, 500 ms of cleanup in a 60-second round is about 0.8%)

With server_memory_gc_rounds=5:

  • Cleanup runs once every 5 rounds
  • Total overhead: one fifth of the per-round figure, so < 0.2% of training time when the per-round overhead is < 1%
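These figures come from spreading one cleanup over gc_rounds rounds: overhead ≈ cleanup time / (gc_rounds × round time). A small helper for back-of-the-envelope estimates follows. The function name is illustrative.

```python
def cleanup_overhead(cleanup_s: float, round_s: float, gc_rounds: int) -> float:
    """Fraction of training time spent in cleanup when it runs every gc_rounds rounds."""
    if gc_rounds <= 0:
        return 0.0  # cleanup disabled
    return cleanup_s / (gc_rounds * round_s)


if __name__ == "__main__":
    # Slow end of the table above: 500 ms cleanup, 60 s rounds.
    for n in (1, 5):
        print(f"gc_rounds={n}: {cleanup_overhead(0.5, 60.0, n):.2%}")
    # gc_rounds=1: 0.83%
    # gc_rounds=5: 0.17%
```

At the short end of the round-duration range (30 s), a 500 ms collection every round would cost about 1.7%. Cleaning up every 5 rounds reduces that to about 0.33%.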

Recommendation: server_memory_gc_rounds=5 provides good memory management with negligible performance impact. Leave cleanup disabled (=0) only if you have measured and confirmed that RSS stays stable without it.

The FedAvg recipe and ScriptRunner support automatic memory cleanup on clients via client_memory_gc_rounds and cuda_empty_cache parameters.

from nvflare.recipe.fedavg import FedAvgRecipe

recipe = FedAvgRecipe(
    name="my_job",
    min_clients=4,
    num_rounds=100,
    train_script="client.py",

    # Server-side cleanup
    server_memory_gc_rounds=5,

    # Client-side cleanup
    client_memory_gc_rounds=1,  # Cleanup every round
    cuda_empty_cache=True,      # Clear GPU cache
)

Swarm Learning uses memory_gc_rounds (not memory_gc_counts) and cuda_empty_cache on SimpleSwarmLearningRecipe:

from nvflare.app_opt.pt.recipes.swarm import SimpleSwarmLearningRecipe

recipe = SimpleSwarmLearningRecipe(
    name="swarm_job",
    model=MyModel(),      # MyModel: placeholder for your own torch.nn.Module subclass
    num_rounds=10,
    train_script="train.py",
    memory_gc_rounds=1,   # Cleanup every round on trainer and aggregator roles
    cuda_empty_cache=True,
)

Note

memory_gc_rounds and cuda_empty_cache are top-level Swarm recipe arguments. Do not pass them inside train_args (they are reserved keys).

Parameters:

  • client_memory_gc_rounds: Run cleanup every N rounds on client (0 = disabled)
  • cuda_empty_cache: If True, call torch.cuda.empty_cache() on cleanup
  • memory_gc_rounds (Swarm): Run cleanup every N rounds (0 = disabled)

When enabled, the following steps run after flare.send() on the client, once every client_memory_gc_rounds rounds (memory_gc_rounds for Swarm):

  1. Runs Python garbage collection (gc.collect())
  2. For glibc: Returns free heap pages to the OS (malloc_trim())
  3. For jemalloc: Relies on auto-decay (no manual action needed)
  4. Optionally clears PyTorch CUDA cache

Note: The cleanup is transparent to the user's training script. No code changes are required in train.py.

For external process execution (launch_external_process=True), memory settings are passed via environment variables:

  • NVFLARE_CLIENT_MEMORY_GC_ROUNDS: Cleanup interval
  • NVFLARE_CUDA_EMPTY_CACHE: GPU cache cleanup (true/false)
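Inside the externally launched process, these variables arrive as strings. The following is a minimal sketch of reading them. It assumes an unset or empty NVFLARE_CLIENT_MEMORY_GC_ROUNDS means 0 (disabled). It also assumes NVFLARE_CUDA_EMPTY_CACHE is compared case-insensitively against "true". NVFlare's own parsing may differ, and the function name read_memory_settings is hypothetical.

```python
import os
from typing import Mapping, Optional, Tuple


def read_memory_settings(env: Optional[Mapping[str, str]] = None) -> Tuple[int, bool]:
    """Return (gc_rounds, cuda_empty_cache) from the NVFLARE_* environment variables.

    Assumed conventions (illustrative): missing or empty gc rounds -> 0 (disabled);
    the CUDA flag is True only for the string "true" in any letter case.
    """
    env = os.environ if env is None else env
    raw_rounds = env.get("NVFLARE_CLIENT_MEMORY_GC_ROUNDS", "").strip()
    gc_rounds = int(raw_rounds) if raw_rounds else 0
    cuda_empty_cache = env.get("NVFLARE_CUDA_EMPTY_CACHE", "false").strip().lower() == "true"
    return gc_rounds, cuda_empty_cache


if __name__ == "__main__":
    print(read_memory_settings({"NVFLARE_CLIENT_MEMORY_GC_ROUNDS": "1",
                                "NVFLARE_CUDA_EMPTY_CACHE": "True"}))  # (1, True)
```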

Recommended settings by role:

Role     server_memory_gc_rounds   client_memory_gc_rounds   MALLOC_ARENA_MAX   cuda_empty_cache
Server   5                         N/A                       4                  N/A
Client   N/A                       1                         2                  True (for GPU)

For PyTorch workloads, jemalloc is recommended over glibc malloc. NVFlare's startup scripts automatically detect and use jemalloc if available.

The generated sub_start.sh script includes jemalloc auto-detection:

# Auto-detects jemalloc at standard locations
for JEMALLOC in /usr/lib/x86_64-linux-gnu/libjemalloc.so.2 \
                /usr/lib64/libjemalloc.so.2 \
                /usr/local/lib/libjemalloc.so; do
    if [ -f "$JEMALLOC" ]; then
        export LD_PRELOAD="${LD_PRELOAD:+$LD_PRELOAD:}$JEMALLOC"
        export MALLOC_CONF="${MALLOC_CONF:-dirty_decay_ms:5000,muzzy_decay_ms:5000}"
        break
    fi
done
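
To install jemalloc if it is not already present:

# Ubuntu/Debian
apt-get install libjemalloc2

# RHEL/CentOS
yum install jemalloc

The nvflare.fuel.utils.memory_utils module provides the following helpers.

from nvflare.fuel.utils.memory_utils import cleanup_memory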

cleanup_memory(cuda_empty_cache=True)

Signature: cleanup_memory(cuda_empty_cache: bool = False) -> None

Performs allocator-aware memory cleanup:

  1. Runs gc.collect()
  2. For glibc: Calls malloc_trim(0)
  3. For jemalloc: Relies on auto-decay (no action needed)
  4. Optionally calls torch.cuda.empty_cache()

from nvflare.fuel.utils.memory_utils import get_allocator_type

allocator = get_allocator_type()  # "glibc", "jemalloc", or "unknown"

Signature: get_allocator_type() -> str

Detects which memory allocator is in use at runtime. Result is cached.
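One way to detect the allocator on Linux is sketched below. It is an assumption, not NVFlare's implementation. The sketch first looks for a jemalloc library in /proc/self/maps, which lists libraries loaded through LD_PRELOAD. If none is found, it asks the loaded C library for gnu_get_libc_version, a symbol that only glibc exports. The result is cached with functools.lru_cache to match the caching described above. The function name detect_allocator is hypothetical.

```python
import ctypes
import functools


@functools.lru_cache(maxsize=None)
def detect_allocator() -> str:
    """Return "jemalloc", "glibc", or "unknown" for the current process (cached)."""
    # A preloaded or linked jemalloc shows up in the process memory maps on Linux.
    try:
        with open("/proc/self/maps") as maps:
            if any("libjemalloc" in line for line in maps):
                return "jemalloc"
    except OSError:
        pass  # no /proc (for example, macOS); fall through to the libc check
    # gnu_get_libc_version is exported only by glibc.
    try:
        libc = ctypes.CDLL(None)
        _ = libc.gnu_get_libc_version
        return "glibc"
    except (OSError, AttributeError, TypeError):
        return "unknown"
```

This heuristic misses jemalloc when it is statically linked or installed under a file name that does not contain "libjemalloc".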

from nvflare.fuel.utils.memory_utils import try_malloc_trim

result = try_malloc_trim()

Signature: try_malloc_trim() -> Optional[int]

Low-level function that returns free heap pages to the OS.

Returns:

  • 1 if memory was released
  • 0 if no memory to release
  • None if not available (non-Linux or non-glibc)

If server memory keeps growing:

  1. Check that MALLOC_ARENA_MAX is set
  2. Enable server_memory_gc_rounds=5
  3. Consider using jemalloc (LD_PRELOAD)
  4. Monitor RSS with top or htop

If client memory keeps growing:

  1. Check that MALLOC_ARENA_MAX=2 is set
  2. Enable client_memory_gc_rounds=1
  3. Enable cuda_empty_cache=True for GPU workloads
  4. Consider using jemalloc

If processes run out of memory:

  1. Reduce the batch size
  2. Enable memory cleanup every round (client_memory_gc_rounds=1 or server_memory_gc_rounds=1)
  3. Check for memory leaks in the training code
  4. Use jemalloc with appropriate decay settings
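As a complement to top or htop, RSS can be logged from inside a process. The following is a minimal Linux-only sketch that reads the VmRSS field of /proc/self/status. The helper names are hypothetical. Recording one sample per round and comparing the first and last samples makes steady growth easy to spot.

```python
from typing import List, Optional


def current_rss_mb() -> Optional[float]:
    """Return this process's resident set size in MiB, or None if /proc is unavailable."""
    try:
        with open("/proc/self/status") as status:
            for line in status:
                if line.startswith("VmRSS:"):
                    kib = int(line.split()[1])  # value is reported in kB (KiB)
                    return kib / 1024.0
    except OSError:
        return None
    return None


def rss_growth_mb(samples: List[float]) -> float:
    """Growth in MiB between the first and last per-round RSS samples."""
    return samples[-1] - samples[0] if len(samples) >= 2 else 0.0
```

For example, append current_rss_mb() to a list after each round. Growth that keeps rising after cleanup is enabled points to a leak in the training code rather than to allocator fragmentation.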