|
25 | 25 | from jax._src import dispatch |
26 | 26 | from jax._src import dtypes |
27 | 27 | from jax._src import pretty_printer as pp |
| 28 | +from jax._src import source_info_util |
28 | 29 | from jax._src import traceback_util |
29 | 30 | from jax._src import tree_util |
30 | 31 | from jax._src.interpreters import ad |
@@ -532,42 +533,48 @@ def _get_pp_rule(eqn, context, settings) -> pp.Doc: |
532 | 533 | x, *flat_idx = eqn.invars |
533 | 534 | transforms = tree_util.tree_unflatten(eqn.params["tree"], flat_idx) |
534 | 535 | lhs = core.pp_vars([y], context, print_shapes=settings.print_shapes) |
| 536 | + annotation = (source_info_util.summarize(eqn.source_info) |
| 537 | + if settings.source_info else None) |
535 | 538 | return pp.concat( |
536 | | - [lhs, pp.text(" <- "), pp_ref_transforms(context, x, transforms)] |
| 539 | + [lhs, pp.text(" <- ", annotation=annotation), |
| 540 | + pp_ref_transforms(context, x, transforms)] |
537 | 541 | ) |
538 | 542 | core.pp_eqn_rules[get_p] = _get_pp_rule |
539 | 543 |
|
540 | 544 | def _swap_pp_rule(eqn, context, settings) -> pp.Doc: |
541 | 545 | y, = eqn.outvars |
542 | 546 | x, v, *flat_idx = eqn.invars |
543 | 547 | transforms = tree_util.tree_unflatten(eqn.params["tree"], flat_idx) |
| 548 | + annotation = (source_info_util.summarize(eqn.source_info) |
| 549 | + if settings.source_info else None) |
544 | 550 | if type(y) is core.DropVar: |
545 | 551 | # In the case of a set (ignored return value), |
546 | 552 | # pretty print `_ = swap x v i` as `x[i] <- v` |
547 | 553 | del y |
548 | 554 | return pp.concat([ |
549 | 555 | pp_ref_transforms(context, x, transforms), |
550 | | - pp.text(" <- "), |
| 556 | + pp.text(" <- ", annotation=annotation), |
551 | 557 | pp.text(core.pp_var(v, context)), |
552 | 558 | ]) |
553 | 559 | else: |
554 | 560 | # pretty-print `y:T = swap x v i` as `y:T, x[i] <- x[i], v` |
555 | 561 | x_i = pp_ref_transforms(context, x, transforms) |
556 | 562 | y = core.pp_vars([y], context, print_shapes=settings.print_shapes) |
557 | | - return pp.concat([y, pp.text(', '), x_i, pp.text(' <- '), |
558 | | - x_i, pp.text(', '), |
559 | | - pp.text(core.pp_var(v, context))]) |
| 563 | + return pp.concat([y, pp.text(', '), x_i, |
| 564 | + pp.text(' <- ', annotation=annotation), x_i, |
| 565 | + pp.text(', '), pp.text(core.pp_var(v, context))]) |
560 | 566 | core.pp_eqn_rules[swap_p] = _swap_pp_rule |
561 | 567 |
|
562 | 568 | def _addupdate_pp_rule(eqn, context, settings) -> pp.Doc: |
563 | | - del settings |
564 | 569 | # pretty-print ` = addupdate x i v` as `x[i] += v` |
565 | 570 | () = eqn.outvars |
566 | 571 | x, v, *flat_idx = eqn.invars |
567 | 572 | transforms = tree_util.tree_unflatten(eqn.params["tree"], flat_idx) |
| 573 | + annotation = (source_info_util.summarize(eqn.source_info) |
| 574 | + if settings.source_info else None) |
568 | 575 | return pp.concat([ |
569 | 576 | pp_ref_transforms(context, x, transforms), |
570 | | - pp.text(" += "), |
| 577 | + pp.text(" += ", annotation=annotation), |
571 | 578 | pp.text(core.pp_var(v, context)), |
572 | 579 | ]) |
573 | 580 | core.pp_eqn_rules[addupdate_p] = _addupdate_pp_rule |
|
0 commit comments