Skip to content

Commit cdfdbc0

Browse files
Angelogebseanprime7
authored andcommitted
Update 25.12.18
Updates include * Workaround for dlpack lifetime issue - 3a3ec91cd2133fa9ef41fb5220fe10eef7fdc0eb by Anxhelo Xhebraj <axhebraj@nvidia.com> Signed-off-by: Anxhelo Xhebraj <axhebraj@nvidia.com> GitOrigin-RevId: 3a3ec91cd2133fa9ef41fb5220fe10eef7fdc0eb
1 parent 230c77b commit cdfdbc0

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

src/jaxpp/core.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1924,7 +1924,7 @@ def infer_donation(
19241924
last_use = last_used(tasked_jaxpr)
19251925

19261926
invar_is_donated = dict(zip(tasked_jaxpr.invars, donated_invars))
1927-
received_vars = set[jcore.Var]()
1927+
undonateable_vars = set[jcore.Var]()
19281928

19291929
least_donation = dict[tuple[int, TaskType], Sequence[bool]]()
19301930
new_eqns = []
@@ -1935,8 +1935,8 @@ def infer_donation(
19351935
donation = tuple(
19361936
is_last_use_for_invar[invar_idx]
19371937
and invar_is_donated.get(invar, True)
1938-
# NOTE: we avoid donating received invars
1939-
and invar not in received_vars
1938+
# NOTE: we avoid donating sent and received invars
1939+
and invar not in undonateable_vars
19401940
for invar_idx, invar in enumerate(task_eqn.invars)
19411941
)
19421942

@@ -1952,10 +1952,10 @@ def infer_donation(
19521952
least_donation[task_eqn.params["task_info"]] = donation
19531953
elif task_eqn.primitive is transfer_p:
19541954
# NOTE: we avoid donating received invars.
1955-
# Variables that are sent are not donated because
1956-
# send_done (below) extends their lifetime to the end of
1957-
# the program
1958-
received_vars.update(task_eqn.outvars)
1955+
# FIXME(#44): sent invars could be donated however jax_primitives.py impls
1956+
# keep scoped holds on sent invars
1957+
undonateable_vars.update(task_eqn.invars)
1958+
undonateable_vars.update(task_eqn.outvars)
19591959
new_eqns.append(task_eqn)
19601960
elif task_eqn.primitive is add_multi_p:
19611961
new_eqns.append(

0 commit comments

Comments
 (0)