Skip to content

Commit af96aae

Browse files
tf32 fix lds
1 parent b108521 commit af96aae

File tree

1 file changed

+33
-39
lines changed

1 file changed

+33
-39
lines changed

tensilelite/Tensile/Components/F32XEmulation.py

+33-39
Original file line numberDiff line numberDiff line change
@@ -57,10 +57,10 @@ def __call__(self):
5757
#
5858
# To:
5959
#
60-
# 0: [0low, 0high]
61-
# 1: [1low, 1high]
62-
# 2: [2low, 2high]
63-
# 3: [3low, 3high]
60+
# 0: [0high, 0low]
61+
# 1: [1high, 1low]
62+
# 2: [2high, 2low]
63+
# 3: [3high, 3low]
6464
#
6565
# Carson: cannot do this, as it will break the 4 stride reassembly
6666
# 0: [0high, 1high]
@@ -72,9 +72,10 @@ def __call__(self):
7272

7373
# # high bits
7474
# tf32mod.add(SNop(waitState=1, comment="1 wait states for ds_read"))
75-
# tf32mod.add(VCvtPkF32toBF16(dst=vgpr("Cvt+0"), src0=vgpr("G2LA+0"), src1=vgpr("G2LA+1")))
75+
tf32mod.add(SWaitCnt(lgkmcnt=0, comment="wait for lds read"))
76+
tf32mod.add(VCvtPkF32toBF16(dst=vgpr("Cvt+0"), src0=vgpr("G2LA+0"), src1=vgpr("G2LA+1")))
7677
# tf32mod.add(SNop(waitState=1, comment="1 wait states for ds_read"))
77-
# tf32mod.add(VCvtPkF32toBF16(dst=vgpr("Cvt+1"), src0=vgpr("G2LA+2"), src1=vgpr("G2LA+3")))
78+
tf32mod.add(VCvtPkF32toBF16(dst=vgpr("Cvt+1"), src0=vgpr("G2LA+2"), src1=vgpr("G2LA+3")))
7879
# tf32mod.add(SNop(waitState=1, comment="1 wait states for ds_read"))
7980
# low bits
8081
tf32mod.add(VCvtBF16toFP32(dst="Cvt+8", src="Cvt+0", vgprMask="", vi=0))
@@ -93,10 +94,12 @@ def __call__(self):
9394
tf32mod.add(SNop(waitState=1, comment="1 wait states for ds_read"))
9495
tf32mod.add(VSubF32(dst=vgpr("Cvt+5"), src0=vgpr("G2LA+3"), src1=vgpr("Cvt+9")))
9596
tf32mod.add(SNop(waitState=1, comment="1 wait states for ds_read"))
97+
tf32mod.add(SWaitCnt(lgkmcnt=0, comment="wait for lds read"))
9698
tf32mod.add(VCvtPkF32toBF16(dst=vgpr("G2LA+0"), src0=vgpr("Cvt+2"), src1=vgpr("G2LA+0")))
9799
tf32mod.add(VCvtPkF32toBF16(dst=vgpr("G2LA+1"), src0=vgpr("Cvt+3"), src1=vgpr("G2LA+1")))
98100
tf32mod.add(VCvtPkF32toBF16(dst=vgpr("G2LA+2"), src0=vgpr("Cvt+4"), src1=vgpr("G2LA+2")))
99101
tf32mod.add(VCvtPkF32toBF16(dst=vgpr("G2LA+3"), src0=vgpr("Cvt+5"), src1=vgpr("G2LA+3")))
102+
tf32mod.add(SWaitCnt(lgkmcnt=0, comment="wait for lds read"))
100103

101104
# tf32mod.add(VMovB32(dst=vgpr("G2LA+0"), src=vgpr("Cvt+0")))
102105
# tf32mod.add(SNop(waitState=1, comment="1 wait states for ds_read"))
@@ -129,50 +132,41 @@ def __call__(self, LocalReadX):
129132
tf32mod = Module()
130133
# Carson: textblock here (or in localread) is causing python issues. rocisa ambiguity issue?
131134
tf32mod.add(TextBlock("/*TF32 Emulation read lds*/\n"))
132-
tf32mod.add(TextBlock(str("label_tf32Read_") + str(F32XEmulationCvtLocalRead.dbgCounter) + ":\n"))
133135
tf32mod.add(SWaitCnt(lgkmcnt=0, comment="wait for lds read"))
136+
tf32mod.add(TextBlock(str("label_tf32Read_") + str(F32XEmulationCvtLocalRead.dbgCounter) + ":\n"))
134137
F32XEmulationCvtLocalRead.dbgCounter += 1
135138
tf32mod.add(LocalReadX(dst=vgpr("Cvt+0"), src=vgpr("LocalReadAddrA"), ds=DSModifiers(na=1, offset=0)))
136139
tf32mod.add(LocalReadX(dst=vgpr("Cvt+1"), src=vgpr("LocalReadAddrA"), ds=DSModifiers(na=1, offset=256)))
137140
tf32mod.add(LocalReadX(dst=vgpr("Cvt+2"), src=vgpr("LocalReadAddrA"), ds=DSModifiers(na=1, offset=512)))
138141
tf32mod.add(LocalReadX(dst=vgpr("Cvt+3"), src=vgpr("LocalReadAddrA"), ds=DSModifiers(na=1, offset=768)))
139142
tf32mod.add(SWaitCnt(lgkmcnt=0, comment="wait for lds read"))
140-
#a = A()
141-
#a.instance_patched = types.MethodType(my_method, a)
142-
# a = VMovB32(dst=vgpr("ValuA_X0_I0+0"), src=vgpr("Cvt+0"), sdwa=SDWAModifiers(dst_sel=SelectBit.WORD_1, src0_sel=SelectBit.WORD_1))
143-
# VMovB32.instance_patched = types.MethodType(my_method, a)
144-
#types.MethodType(issueLatencyOp, op)
145-
#op.issueLatency = types.MethodType(issueLatencyOp, op)
146-
#op.issueLatency()
147-
#tf32mod.add(op)
148-
149143

150-
151-
# tf32mod.add(VMovB32(dst=vgpr("ValuA_X0_I0+0"), src=vgpr("Cvt+0"), sdwa=SDWAModifiers(dst_sel=SelectBit.WORD_1, src0_sel=SelectBit.WORD_1)))
152-
# tf32mod.add(VMovB32(dst=vgpr("ValuA_X0_I0+1"), src=vgpr("Cvt+2"), sdwa=SDWAModifiers(dst_sel=SelectBit.WORD_1, src0_sel=SelectBit.WORD_1)))
153-
# tf32mod.add(VMovB32(dst=vgpr("ValuA_X0_I0+2"), src=vgpr("Cvt+0"), sdwa=SDWAModifiers(dst_sel=SelectBit.WORD_1, src0_sel=SelectBit.WORD_0)))
154-
# tf32mod.add(VMovB32(dst=vgpr("ValuA_X0_I0+3"), src=vgpr("Cvt+2"), sdwa=SDWAModifiers(dst_sel=SelectBit.WORD_1, src0_sel=SelectBit.WORD_0)))
155-
# tf32mod.add(SWaitCnt(lgkmcnt=0, comment="wait for lds read"))
156-
# tf32mod.add(VMovB32(dst=vgpr("ValuA_X0_I0+0"), src=vgpr("Cvt+1"), sdwa=SDWAModifiers(dst_sel=SelectBit.WORD_0, src0_sel=SelectBit.WORD_1)))
157-
# tf32mod.add(VMovB32(dst=vgpr("ValuA_X0_I0+1"), src=vgpr("Cvt+3"), sdwa=SDWAModifiers(dst_sel=SelectBit.WORD_0, src0_sel=SelectBit.WORD_1)))
158-
# tf32mod.add(VMovB32(dst=vgpr("ValuA_X0_I0+2"), src=vgpr("Cvt+1"), sdwa=SDWAModifiers(dst_sel=SelectBit.WORD_0, src0_sel=SelectBit.WORD_0)))
159-
# tf32mod.add(VMovB32(dst=vgpr("ValuA_X0_I0+3"), src=vgpr("Cvt+3"), sdwa=SDWAModifiers(dst_sel=SelectBit.WORD_0, src0_sel=SelectBit.WORD_0)))
160-
161-
162-
163-
tf32mod.add(VMovB32(dst=vgpr("ValuA_X0_I0+0"), src=vgpr("Cvt+0"), sdwa=SDWAModifiers(dst_sel=SelectBit.WORD_0, src0_sel=SelectBit.WORD_0)))
164-
tf32mod.add(VMovB32(dst=vgpr("ValuA_X0_I0+0"), src=vgpr("Cvt+1"), sdwa=SDWAModifiers(dst_sel=SelectBit.WORD_1, src0_sel=SelectBit.WORD_0)))
165-
tf32mod.add(VMovB32(dst=vgpr("ValuA_X0_I0+1"), src=vgpr("Cvt+2"), sdwa=SDWAModifiers(dst_sel=SelectBit.WORD_0, src0_sel=SelectBit.WORD_0)))
166-
tf32mod.add(VMovB32(dst=vgpr("ValuA_X0_I0+1"), src=vgpr("Cvt+3"), sdwa=SDWAModifiers(dst_sel=SelectBit.WORD_1, src0_sel=SelectBit.WORD_0)))
144+
#pack high bits
145+
tf32mod.add(VMovB32(dst=vgpr("ValuA_X0_I0+0"), src=vgpr("Cvt+0"), sdwa=SDWAModifiers(dst_sel=SelectBit.WORD_0, src0_sel=SelectBit.WORD_1)))
146+
tf32mod.add(SWaitCnt(lgkmcnt=0, comment="wait for lds read"))
147+
tf32mod.add(VMovB32(dst=vgpr("ValuA_X0_I0+0"), src=vgpr("Cvt+1"), sdwa=SDWAModifiers(dst_sel=SelectBit.WORD_1, src0_sel=SelectBit.WORD_1)))
148+
tf32mod.add(SWaitCnt(lgkmcnt=0, comment="wait for lds read"))
149+
tf32mod.add(VMovB32(dst=vgpr("ValuA_X0_I0+1"), src=vgpr("Cvt+2"), sdwa=SDWAModifiers(dst_sel=SelectBit.WORD_0, src0_sel=SelectBit.WORD_1)))
150+
tf32mod.add(SWaitCnt(lgkmcnt=0, comment="wait for lds read"))
151+
tf32mod.add(VMovB32(dst=vgpr("ValuA_X0_I0+1"), src=vgpr("Cvt+3"), sdwa=SDWAModifiers(dst_sel=SelectBit.WORD_1, src0_sel=SelectBit.WORD_1)))
152+
tf32mod.add(SWaitCnt(lgkmcnt=0, comment="wait for lds read"))
153+
#pack low bits
154+
tf32mod.add(VMovB32(dst=vgpr("ValuA_X0_I0+2"), src=vgpr("Cvt+0"), sdwa=SDWAModifiers(dst_sel=SelectBit.WORD_0, src0_sel=SelectBit.WORD_0)))
155+
tf32mod.add(SWaitCnt(lgkmcnt=0, comment="wait for lds read"))
156+
tf32mod.add(VMovB32(dst=vgpr("ValuA_X0_I0+2"), src=vgpr("Cvt+1"), sdwa=SDWAModifiers(dst_sel=SelectBit.WORD_1, src0_sel=SelectBit.WORD_0)))
157+
tf32mod.add(SWaitCnt(lgkmcnt=0, comment="wait for lds read"))
158+
tf32mod.add(VMovB32(dst=vgpr("ValuA_X0_I0+3"), src=vgpr("Cvt+2"), sdwa=SDWAModifiers(dst_sel=SelectBit.WORD_0, src0_sel=SelectBit.WORD_0)))
159+
tf32mod.add(SWaitCnt(lgkmcnt=0, comment="wait for lds read"))
160+
tf32mod.add(VMovB32(dst=vgpr("ValuA_X0_I0+3"), src=vgpr("Cvt+3"), sdwa=SDWAModifiers(dst_sel=SelectBit.WORD_1, src0_sel=SelectBit.WORD_0)))
167161
tf32mod.add(SWaitCnt(lgkmcnt=0, comment="wait for lds read"))
168-
tf32mod.add(VMovB32(dst=vgpr("ValuA_X0_I0+2"), src=vgpr("Cvt+0"), sdwa=SDWAModifiers(dst_sel=SelectBit.WORD_0, src0_sel=SelectBit.WORD_1)))
169-
tf32mod.add(VMovB32(dst=vgpr("ValuA_X0_I0+2"), src=vgpr("Cvt+1"), sdwa=SDWAModifiers(dst_sel=SelectBit.WORD_1, src0_sel=SelectBit.WORD_1)))
170-
tf32mod.add(VMovB32(dst=vgpr("ValuA_X0_I0+3"), src=vgpr("Cvt+2"), sdwa=SDWAModifiers(dst_sel=SelectBit.WORD_0, src0_sel=SelectBit.WORD_1)))
171-
tf32mod.add(VMovB32(dst=vgpr("ValuA_X0_I0+3"), src=vgpr("Cvt+3"), sdwa=SDWAModifiers(dst_sel=SelectBit.WORD_1, src0_sel=SelectBit.WORD_1)))
172-
173162

163+
# read format:
164+
# 0: [0high, 1high]
165+
# 1: [2high, 3high]
166+
# 2: [0low, 1low]
167+
# 3: [2low, 3low]
168+
#
174169

175-
tf32mod.add(SWaitCnt(lgkmcnt=0, comment="wait for lds read"))
176170
tf32mod.add(TextBlock(str("label_tf32Read_") + str(F32XEmulationCvtLocalRead.dbgCounter) + ":\n"))
177171
F32XEmulationCvtLocalRead.dbgCounter += 1
178172
return tf32mod

0 commit comments

Comments
 (0)