Skip to content

Commit 150ff1f

Browse files
isurufinducer
authored andcommitted
rename_inames: replace old inames that appear as params in other domains
1 parent 4e49be7 commit 150ff1f

File tree

2 files changed

+28
-1
lines changed

2 files changed

+28
-1
lines changed

loopy/transform/iname.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2327,6 +2327,7 @@ def rename_inames(kernel, old_inames, new_iname, existing_ok=False,
23272327
raise LoopyError(f"iname '{new_iname}' conflicts with an existing identifier"
23282328
" --cannot rename")
23292329

2330+
orig_old_inames = old_inames
23302331
if not does_exist:
23312332
# {{{ rename old_inames[0] -> new_iname
23322333
# so that the code below can focus on "merging" inames that already exist
@@ -2404,6 +2405,16 @@ def does_insn_involve_iname(kernel, insn, *args):
24042405
smap.map_kernel(kernel, within=does_insn_involve_iname,
24052406
map_tvs=False, map_args=False))
24062407

2408+
# replace instances where the old inames appear as a param
2409+
new_domains = []
2410+
for dom in kernel.domains:
2411+
for old_iname in orig_old_inames:
2412+
d = dom.get_var_dict()
2413+
if old_iname in d and new_iname not in d:
2414+
var_type, var_num = d[old_iname]
2415+
dom = dom.set_dim_name(var_type, var_num, new_iname)
2416+
new_domains.append(dom)
2417+
24072418
new_instructions = [insn.copy(within_inames=((insn.within_inames
24082419
- frozenset(old_inames))
24092420
| frozenset([new_iname])))
@@ -2412,7 +2423,7 @@ def does_insn_involve_iname(kernel, insn, *args):
24122423
else insn
24132424
for insn in kernel.instructions]
24142425

2415-
kernel = kernel.copy(instructions=new_instructions)
2426+
kernel = kernel.copy(instructions=new_instructions, domains=new_domains)
24162427

24172428
return kernel
24182429

test/test_transform.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1468,6 +1468,22 @@ def test_rename_inames(ctx_factory):
14681468
lp.auto_test_vs_ref(knl, ctx, ref_knl)
14691469

14701470

1471+
def test_rename_inames_with_params(ctx_factory):
1472+
# https://github.com/inducer/loopy/issues/726
1473+
ctx = ctx_factory()
1474+
1475+
knl = lp.make_kernel(
1476+
[
1477+
"{ [i]: 0<=i<10 }",
1478+
"{ [k]: 0<=k<i }",
1479+
],
1480+
"out[i, k] = 2"
1481+
)
1482+
ref_knl = knl
1483+
knl = lp.rename_inames(knl, ["i"], "j")
1484+
lp.auto_test_vs_ref(knl, ctx, ref_knl)
1485+
1486+
14711487
def test_buffer_array_preserves_rev_deps(ctx_factory):
14721488
# See https://github.com/inducer/loopy/issues/546
14731489
ctx = ctx_factory()

0 commit comments

Comments
 (0)