Commit 6a254a4

amd-jianli12 authored and tensorflower-gardener committed
PR tensorflow#21845: [ROCM] Add missing triton MLIR int4 -> int8 rewrite pass for ROCM
Imported from GitHub PR openxla/xla#21845

```
TritonTest.DotWithInt4WeightsOnLhsFusedWithMultiplyByChannelScales
TritonTest.NonstandardLayoutInt4
TritonTest.DotWithI4WeightsOnLhsWithBitcastTo3dTensor
TritonTest.DotWithI4WeightsOnLhsWithNonStandardLayoutAndMultplyInEpilogue
TritonTest.LHSWithMinorDimEqualTo1
TritonTest.RHSWithMinorDimEqualTo1
TritonTest.LHSNonMinorContractingDim
TritonTest.LHSNonMinorContractingDimWithBatchDim0
TritonTest.LHSMinorContractingDim
TritonTest.ConvertPlusNegate
TritonTest.LHSMinorContractingDimWithBatchDim0
TritonTest.RHSTestWithNotMinorContractingDim
TritonTest.RHSTestWithMinorContractingDim
TritonTest.RHSTestWithMinorContractingDimWithBatchDim
TritonTest.RHSTestWithNotMinorContractingDimWithBatchDim0
ParametrizedTritonTest.Int4WeightsOnTheLhs
ParametrizedTritonTest.Int4WeightsOnTheLhsWithBatchDim
ParametrizedTritonTest.Int4WeightsOnTheRhs
```

The tests above fail on the ROCm side because the int4 rewriting was moved from the legacy matmul emitter to an MLIR pass. That MLIR pass is currently missing from the ROCm Triton pipeline, so this change adds it in place.

@xla-rotation: would you please take a look?

Copybara import of the project:

-- 75e78ad365a9d55f6e299c7b64400447ceebb26d by Jian Li <[email protected]>:

[ROCM] Add missing triton MLIR int4 -> int8 rewrite pass for ROCM

Merging this change closes tensorflow#21845

PiperOrigin-RevId: 720233927
1 parent d834a72 commit 6a254a4

File tree: 1 file changed, +6 −0 lines


third_party/xla/xla/backends/gpu/codegen/triton/compilation_pipeline_rocm.cc

Lines changed: 6 additions & 0 deletions
```diff
@@ -25,6 +25,7 @@ limitations under the License.
 #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Transforms/Passes.h"
+#include "xla/backends/gpu/codegen/triton/xla_triton_passes.h"
 #include "xla/service/gpu/llvm_gpu_backend/amdgpu_backend.h"
 #include "xla/service/gpu/matmul_utils.h"
 #include "xla/service/hlo_module_config.h"
@@ -47,6 +48,7 @@ namespace ma = ::mlir::arith;
 namespace mm = ::mlir::math;
 namespace ml = ::mlir::LLVM;
 namespace mt = ::mlir::triton;
+namespace mt_xla = ::mlir::triton::xla;

 using ::llvm::SmallVector;
 using mlir::ArrayRef;
@@ -64,6 +66,10 @@ absl::Status CreateTritonPipeline(mlir::OpPassManager* pm,
   const int threadsPerWarp = 32;
   auto cc = se::RocmComputeCapability(std::move(arch_name));

+  if (is_xla_fusion) {
+    pm->addPass(mt_xla::CreateInt4ToPackedInt4RewritePass());
+  }
+
   // Based on make_ttir() in
   // @triton//:third_party/amd/backend/compiler.py
   pm->addPass(mlir::createInlinerPass());
```

0 commit comments
