Skip to content

[ Jax FE] Todo Fix memory leak #35643

Open
andersendsa wants to merge 3 commits intoopenvinotoolkit:masterfrom
andersendsa:memory_leak
Open

[ Jax FE] Todo Fix memory leak #35643
andersendsa wants to merge 3 commits intoopenvinotoolkit:masterfrom
andersendsa:memory_leak

Conversation

@andersendsa
Copy link
Copy Markdown

@andersendsa andersendsa commented May 2, 2026

Details:

The JAX frontend JaxprPythonDecoder was storing all decoders created during visit_subgraph in a standard Python list (self.m_decoders). This created strong references, causing memory usage to increase linearly with model size, as noted by an existing TODO.

Changes:

  • Replaced the self.m_decoders = [] list with weakref.WeakSet() to prevent unbound memory growth

  • Updated all instances of m_decoders.append() to m_decoders.add().

Closes : jax-ml/jax#37352

AI Assistance:

  • *AI assistance used: no

@andersendsa andersendsa requested a review from a team as a code owner May 2, 2026 06:23
@github-actions github-actions Bot added category: Python API OpenVINO Python bindings category: JAX FE OpenVINO JAX FrontEnd labels May 2, 2026
@sys-openvino-ci sys-openvino-ci added the ExternalPR External contributor label May 2, 2026
@andersendsa
Copy link
Copy Markdown
Author

Hi @mlukasze could you please tell if this pr needs any changes or if it is good to merge

@mlukasze mlukasze requested review from Copilot and mvafin and removed request for a team May 4, 2026 06:06
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR addresses unbounded memory growth in the Python-side JAX frontend decoder traversal by changing the decoder keep-alive container from a strong-reference list to a weak-reference collection.

Changes:

  • Replaced self.m_decoders = [] with weakref.WeakSet() in JaxprPythonDecoder.
  • Updated all m_decoders.append(...) call sites to m_decoders.add(...).

Comment on lines +85 to +86

self.m_decoders = weakref.WeakSet()
@andersendsa andersendsa changed the title [TODO] Fix memory leak [ Jax FE] Todo Fix memory leak May 4, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

category: JAX FE OpenVINO JAX FrontEnd category: Python API OpenVINO Python bindings ExternalPR External contributor

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants