diff --git a/src/bindings/python/src/openvino/frontend/jax/jaxpr_decoder.py b/src/bindings/python/src/openvino/frontend/jax/jaxpr_decoder.py index 057e30eaac4092..0f9352af206f9f 100644 --- a/src/bindings/python/src/openvino/frontend/jax/jaxpr_decoder.py +++ b/src/bindings/python/src/openvino/frontend/jax/jaxpr_decoder.py @@ -14,6 +14,7 @@ else: import jax.extend as jex +import weakref from openvino.frontend.jax.py_jax_frontend import _FrontEndJaxDecoder as Decoder from openvino import PartialShape, Type as OVType, OVAny from openvino.frontend.jax.utils import jax_array_to_ov_const, get_ov_type_for_value, \ @@ -81,8 +82,8 @@ def __init__(self, jaxpr, name=None, literals=None): if converted is not None: self.params.update(converted) - # TODO: this implementation may lead to memory increasing. Any better solution? - self.m_decoders = [] + + self.m_decoders = weakref.WeakSet() def inputs(self) -> list[int]: if isinstance(self.jaxpr, jex.core.JaxprEqn): @@ -147,7 +148,7 @@ def visit_subgraph(self, node_visitor) -> None: if isinstance(self.jaxpr, jex.core.JaxprEqn): return for _, decoder in self.params.items(): - self.m_decoders.append(decoder) + self.m_decoders.add(decoder) node_visitor(decoder) for idx, node in enumerate(self.jaxpr.constvars): decoder = self.convert_literal_to_constant_node( @@ -155,7 +156,7 @@ def visit_subgraph(self, node_visitor) -> None: name=self.name + "/" + f"const({id(node)})", output_id=id(node) ) - self.m_decoders.append(decoder) + self.m_decoders.add(decoder) node_visitor(decoder) # Visit every `JaxEqn` in the jaxpr, see https://github.com/google/jax/blob/jaxlib-v0.4.29/jax/_src/core.py#L285 for node in self.jaxpr.eqns: @@ -166,7 +167,7 @@ def visit_subgraph(self, node_visitor) -> None: literal_decoders.append(literal_decoder) node_visitor(literal_decoder) decoder = JaxprPythonDecoder(node, name=self.name + "/" + node.primitive.name, literals=literal_decoders) - self.m_decoders.append(decoder) + self.m_decoders.add(decoder) node_visitor(decoder) def get_op_type(self) -> str: