@@ -1924,7 +1924,7 @@ def infer_donation(
19241924 last_use = last_used (tasked_jaxpr )
19251925
19261926 invar_is_donated = dict (zip (tasked_jaxpr .invars , donated_invars ))
1927- received_vars = set [jcore .Var ]()
1927+ undonateable_vars = set [jcore .Var ]()
19281928
19291929 least_donation = dict [tuple [int , TaskType ], Sequence [bool ]]()
19301930 new_eqns = []
@@ -1935,8 +1935,8 @@ def infer_donation(
19351935 donation = tuple (
19361936 is_last_use_for_invar [invar_idx ]
19371937 and invar_is_donated .get (invar , True )
1938- # NOTE: we avoid donating received invars
1939- and invar not in received_vars
1938+ # NOTE: we avoid donating sent and received invars
1939+ and invar not in undonateable_vars
19401940 for invar_idx , invar in enumerate (task_eqn .invars )
19411941 )
19421942
@@ -1952,10 +1952,10 @@ def infer_donation(
19521952 least_donation [task_eqn .params ["task_info" ]] = donation
19531953 elif task_eqn .primitive is transfer_p :
19541954 # NOTE: we avoid donating received invars.
1955- # Variables that are sent are not donated because
1956- # send_done (below) extends their lifetime to the end of
1957- # the program
1958- received_vars .update (task_eqn .outvars )
1955+ # FIXME(#44): sent invars could be donated however jax_primitives.py impls
1956+ # keep scoped holds on sent invars
1957+ undonateable_vars . update ( task_eqn . invars )
1958+ undonateable_vars .update (task_eqn .outvars )
19591959 new_eqns .append (task_eqn )
19601960 elif task_eqn .primitive is add_multi_p :
19611961 new_eqns .append (
0 commit comments