Skip to content

Commit de9c154

Browse files
committed
fix flaky garbage collection test by using tracemalloc instead of RSS
The test measured memory cleanup via process RSS (psutil), but most OS memory allocators do not return freed pages to the OS, so RSS stays flat after deallocation. JAX's XLA allocator made this even less reliable. Switch to tracemalloc which accurately tracks Python-level allocations regardless of OS allocator behavior.
1 parent d47c8a6 commit de9c154

1 file changed

Lines changed: 36 additions & 46 deletions

File tree

tests/test_memory_validation_disk_backed.py

Lines changed: 36 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -413,53 +413,43 @@ def __call__(self, X):
413413
logger.info(f" File: {result.filename}")
414414
logger.info(f" Size on disk: {os.path.getsize(result.filename) / 1024 / 1024:.2f} MB")
415415

416-
@pytest.mark.skipif(not PSUTIL_AVAILABLE, reason="Requires psutil for memory monitoring")
417416
def test_garbage_collection_effectiveness(self):
418-
"""Test that garbage collection is effective in cleaning up JAX arrays."""
419-
pytest.importorskip("jax")
420-
421-
import jax.numpy as jnp
422-
423-
initial_memory = get_memory_usage()
424-
logger.info(f"Initial memory: {initial_memory:.2f} MB")
425-
426-
# Create large JAX arrays
427-
large_arrays = []
428-
for i in range(10):
429-
arr = jnp.ones((1000, 1000)) # ~8MB per array
430-
large_arrays.append(arr)
431-
432-
after_allocation = get_memory_usage()
433-
logger.info(f"After allocating JAX arrays: {after_allocation:.2f} MB")
434-
435-
# Convert to numpy and delete JAX references
436-
numpy_arrays = []
437-
for arr in large_arrays:
438-
numpy_arrays.append(np.asarray(arr))
439-
440-
# Delete JAX arrays
441-
del large_arrays
442-
gc.collect()
443-
444-
after_conversion = get_memory_usage()
445-
logger.info(f"After JAX to numpy conversion: {after_conversion:.2f} MB")
446-
447-
# Delete numpy arrays
448-
del numpy_arrays
449-
gc.collect()
450-
451-
final_memory = get_memory_usage()
452-
logger.info(f"Final memory after cleanup: {final_memory:.2f} MB")
453-
454-
# Memory should return close to initial levels
455-
memory_cleanup_efficiency = (after_allocation - final_memory) / (after_allocation - initial_memory)
456-
logger.info(f"Memory cleanup efficiency: {memory_cleanup_efficiency:.1%}")
457-
458-
# We should recover at least some of the allocated memory
459-
# Lower threshold for testing environments where OS memory management may be different
460-
assert memory_cleanup_efficiency > 0.2, (
461-
f"Poor memory cleanup efficiency: {memory_cleanup_efficiency:.1%}"
462-
)
417+
"""Test that garbage collection reclaims memory from large numpy arrays.
418+
419+
Uses tracemalloc (Python-level allocation tracking) instead of RSS
420+
because most OS memory allocators do not return freed pages to the OS,
421+
so process RSS stays flat even after successful deallocation.
422+
"""
423+
import tracemalloc
424+
425+
tracemalloc.start()
426+
try:
427+
initial_memory = tracemalloc.get_traced_memory()[0] / 1024 / 1024
428+
logger.info(f"Initial traced memory: {initial_memory:.2f} MB")
429+
430+
# Create large numpy arrays (~80 MB total)
431+
large_arrays = [np.ones((1000, 1000), dtype=np.float64) for _ in range(10)]
432+
433+
after_allocation = tracemalloc.get_traced_memory()[0] / 1024 / 1024
434+
memory_increase = after_allocation - initial_memory
435+
logger.info(f"After allocating numpy arrays: {after_allocation:.2f} MB (+{memory_increase:.1f} MB)")
436+
437+
# Delete arrays and force garbage collection
438+
del large_arrays
439+
gc.collect()
440+
441+
final_memory = tracemalloc.get_traced_memory()[0] / 1024 / 1024
442+
logger.info(f"Final traced memory after cleanup: {final_memory:.2f} MB")
443+
444+
memory_cleanup_efficiency = (after_allocation - final_memory) / memory_increase
445+
logger.info(f"Memory cleanup efficiency: {memory_cleanup_efficiency:.1%}")
446+
447+
# We should recover at least most of the allocated memory
448+
assert memory_cleanup_efficiency > 0.2, (
449+
f"Poor memory cleanup efficiency: {memory_cleanup_efficiency:.1%}"
450+
)
451+
finally:
452+
tracemalloc.stop()
463453

464454
logger.info("Garbage collection effectiveness test passed")
465455

0 commit comments

Comments
 (0)