Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions src/bindings/python/src/openvino/frontend/jax/jaxpr_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
else:
import jax.extend as jex

import weakref
from openvino.frontend.jax.py_jax_frontend import _FrontEndJaxDecoder as Decoder
from openvino import PartialShape, Type as OVType, OVAny
from openvino.frontend.jax.utils import jax_array_to_ov_const, get_ov_type_for_value, \
Expand Down Expand Up @@ -81,8 +82,8 @@ def __init__(self, jaxpr, name=None, literals=None):
if converted is not None:
self.params.update(converted)

# TODO: this implementation may lead to memory increasing. Any better solution?
self.m_decoders = []

self.m_decoders = weakref.WeakSet()
Comment on lines +85 to +86

def inputs(self) -> list[int]:
if isinstance(self.jaxpr, jex.core.JaxprEqn):
Expand Down Expand Up @@ -147,15 +148,15 @@ def visit_subgraph(self, node_visitor) -> None:
if isinstance(self.jaxpr, jex.core.JaxprEqn):
return
for _, decoder in self.params.items():
self.m_decoders.append(decoder)
self.m_decoders.add(decoder)
node_visitor(decoder)
for idx, node in enumerate(self.jaxpr.constvars):
decoder = self.convert_literal_to_constant_node(
literal=self.literals[idx],
name=self.name + "/" + f"const({id(node)})",
output_id=id(node)
)
self.m_decoders.append(decoder)
self.m_decoders.add(decoder)
node_visitor(decoder)
# Visit every `JaxEqn` in the jaxpr, see https://github.com/google/jax/blob/jaxlib-v0.4.29/jax/_src/core.py#L285
for node in self.jaxpr.eqns:
Expand All @@ -166,7 +167,7 @@ def visit_subgraph(self, node_visitor) -> None:
literal_decoders.append(literal_decoder)
node_visitor(literal_decoder)
decoder = JaxprPythonDecoder(node, name=self.name + "/" + node.primitive.name, literals=literal_decoders)
self.m_decoders.append(decoder)
self.m_decoders.add(decoder)
node_visitor(decoder)

def get_op_type(self) -> str:
Expand Down
Loading